diff --git a/README.md b/README.md
index e38a5436a720dc3679ac8beba43c40e6c72aabd7..a7b2fd501a948e6cf9369325d433a66c032c316c 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,14 @@
---
-title: LAM
-emoji: 🌍
-colorFrom: green
-colorTo: pink
+title: LAM_test
+emoji: ⚡
+colorFrom: red
+colorTo: indigo
sdk: gradio
-sdk_version: 5.23.3
+sdk_version: 5.20.1
app_file: app.py
pinned: false
+license: apache-2.0
+short_description: Large Avatar Model for One-shot Animatable Gaussian Head
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..766395caa6b5586cbd628f430ad5626189244e55
--- /dev/null
+++ b/app.py
@@ -0,0 +1,568 @@
+# Copyright (c) 2024-2025, Yisheng He, Yuan Dong
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+os.system("rm -rf /data-nvme/zerogpu-offload/")
+os.system("pip install chumpy")
+# os.system("pip uninstall -y basicsr")
+os.system("pip install Cython")
+os.system("pip install ./wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl")
+os.system("pip install ./wheels/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl")
+os.system("pip install ./wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl --force-reinstall")
+os.system(
+ "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html")
+os.system("pip install numpy==1.23.0")
+
+import cv2
+import sys
+import base64
+import subprocess
+
+import argparse
+from glob import glob
+import gradio as gr
+import numpy as np
+from PIL import Image
+from omegaconf import OmegaConf
+
+import torch
+import moviepy.editor as mpy
+from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image
+from lam.utils.ffmpeg_utils import images_to_video
+
+import spaces
+
+
+def compile_module(subfolder, script):
+ try:
+ # Save the current working directory
+ current_dir = os.getcwd()
+ # Change directory to the subfolder
+ os.chdir(os.path.join(current_dir, subfolder))
+ # Run the compilation command
+ result = subprocess.run(
+ ["sh", script],
+ capture_output=True,
+ text=True,
+ check=True
+ )
+ # Print the compilation output
+ print("Compilation output:", result.stdout)
+
+ except Exception as e:
+ # Print any error that occurred
+ print(f"An error occurred: {e}")
+ finally:
+ # Ensure returning to the original directory
+ os.chdir(current_dir)
+ print("Returned to the original directory.")
+
+
+# compile flame_tracking dependence submodule
+compile_module("external/landmark_detection/FaceBoxesV2/utils/", "make.sh")
+from flame_tracking_single_image import FlameTrackingSingleImage
+
+
+def launch_pretrained():
+ from huggingface_hub import snapshot_download, hf_hub_download
+ # launch pretrained for flame tracking.
+ hf_hub_download(repo_id='yuandong513/flametracking_model',
+ repo_type='model',
+ filename='pretrain_model.tar',
+ local_dir='./')
+ os.system('tar -xf pretrain_model.tar && rm pretrain_model.tar')
+ # launch human model files
+ hf_hub_download(repo_id='3DAIGC/LAM-assets',
+ repo_type='model',
+ filename='LAM_human_model.tar',
+ local_dir='./')
+ os.system('tar -xf LAM_human_model.tar && rm LAM_human_model.tar')
+ # launch pretrained for LAM
+ model_dir = hf_hub_download(repo_id="3DAIGC/LAM-20K", repo_type="model", local_dir="./exps/releases/lam/lam-20k/step_045500/", filename="config.json")
+ print(model_dir)
+ model_dir = hf_hub_download(repo_id="3DAIGC/LAM-20K", repo_type="model", local_dir="./exps/releases/lam/lam-20k/step_045500/", filename="model.safetensors")
+ print(model_dir)
+ model_dir = hf_hub_download(repo_id="3DAIGC/LAM-20K", repo_type="model", local_dir="./exps/releases/lam/lam-20k/step_045500/", filename="README.md")
+ print(model_dir)
+ # launch example for LAM
+ hf_hub_download(repo_id='3DAIGC/LAM-assets',
+ repo_type='model',
+ filename='LAM_assets.tar',
+ local_dir='./')
+ os.system('tar -xf LAM_assets.tar && rm LAM_assets.tar')
+ hf_hub_download(repo_id='3DAIGC/LAM-assets',
+ repo_type='model',
+ filename='config.json',
+ local_dir='./tmp/')
+
+
+def launch_env_not_compile_with_cuda():
+ os.system('pip install chumpy')
+ os.system('pip install numpy==1.23.0')
+ os.system(
+ 'pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html'
+ )
+
+
+def assert_input_image(input_image):
+ if input_image is None:
+ raise gr.Error('No image selected or uploaded!')
+
+
+def prepare_working_dir():
+ import tempfile
+ working_dir = tempfile.TemporaryDirectory()
+ return working_dir
+
+
+def init_preprocessor():
+ from lam.utils.preprocess import Preprocessor
+ global preprocessor
+ preprocessor = Preprocessor()
+
+
+def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool,
+ working_dir):
+ image_raw = os.path.join(working_dir.name, 'raw.png')
+ with Image.fromarray(image_in) as img:
+ img.save(image_raw)
+ image_out = os.path.join(working_dir.name, 'rembg.png')
+ success = preprocessor.preprocess(image_path=image_raw,
+ save_path=image_out,
+ rmbg=remove_bg,
+ recenter=recenter)
+ assert success, f'Failed under preprocess_fn!'
+ return image_out
+
+
+def get_image_base64(path):
+ with open(path, 'rb') as image_file:
+ encoded_string = base64.b64encode(image_file.read()).decode()
+ return f'data:image/png;base64,{encoded_string}'
+
+
+def save_imgs_2_video(imgs, v_pth, fps=30):
+ # moviepy example
+ from moviepy.editor import ImageSequenceClip, VideoFileClip
+ images = [image.astype(np.uint8) for image in imgs]
+ clip = ImageSequenceClip(images, fps=fps)
+ # final_duration = len(images) / fps
+ # clip = clip.subclip(0, final_duration)
+ clip = clip.subclip(0, len(images) / fps)
+ clip.write_videofile(v_pth, codec='libx264')
+
+ import cv2
+ cap = cv2.VideoCapture(v_pth)
+ nf = cap.get(cv2.CAP_PROP_FRAME_COUNT)
+ if nf != len(images):
+ print("="*100+f"\n{v_pth} moviepy saved video frame error."+"\n"+"="*100)
+ print(f"Video saved successfully at {v_pth}")
+
+
+def add_audio_to_video(video_path, out_path, audio_path, fps=30):
+ # Import necessary modules from moviepy
+ from moviepy.editor import VideoFileClip, AudioFileClip
+
+ # Load video file into VideoFileClip object
+ video_clip = VideoFileClip(video_path)
+
+ # Load audio file into AudioFileClip object
+ audio_clip = AudioFileClip(audio_path)
+
+ # Hard code clip audio
+ if audio_clip.duration > 10:
+ audio_clip = audio_clip.subclip(0, 10)
+
+ # Attach audio clip to video clip (replaces existing audio)
+ video_clip_with_audio = video_clip.set_audio(audio_clip)
+
+ # Export final video with audio using standard codecs
+ video_clip_with_audio.write_videofile(out_path, codec='libx264', audio_codec='aac', fps=fps)
+
+ print(f"Audio added successfully at {out_path}")
+
+
+def parse_configs():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str)
+ parser.add_argument("--infer", type=str)
+ args, unknown = parser.parse_known_args()
+
+ cfg = OmegaConf.create()
+ cli_cfg = OmegaConf.from_cli(unknown)
+
+ # parse from ENV
+ if os.environ.get("APP_INFER") is not None:
+ args.infer = os.environ.get("APP_INFER")
+ if os.environ.get("APP_MODEL_NAME") is not None:
+ cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")
+
+ args.config = args.infer if args.config is None else args.config
+
+ if args.config is not None:
+ cfg_train = OmegaConf.load(args.config)
+ cfg.source_size = cfg_train.dataset.source_image_res
+ try:
+ cfg.src_head_size = cfg_train.dataset.src_head_size
+ except:
+ cfg.src_head_size = 112
+ cfg.render_size = cfg_train.dataset.render_image.high
+ _relative_path = os.path.join(
+ cfg_train.experiment.parent,
+ cfg_train.experiment.child,
+ os.path.basename(cli_cfg.model_name).split("_")[-1],
+ )
+
+ cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
+ cfg.image_dump = os.path.join("exps", "images", _relative_path)
+ cfg.video_dump = os.path.join("exps", "videos", _relative_path) # output path
+
+ if args.infer is not None:
+ cfg_infer = OmegaConf.load(args.infer)
+ cfg.merge_with(cfg_infer)
+ cfg.setdefault(
+ "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
+ )
+ cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
+ cfg.setdefault(
+ "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
+ )
+ cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))
+
+ cfg.motion_video_read_fps = 30
+ cfg.merge_with(cli_cfg)
+
+ cfg.setdefault("logger", "INFO")
+
+ assert cfg.model_name is not None, "model_name is required"
+
+ return cfg, cfg_train
+
+
+def demo_lam(flametracking, lam, cfg):
+ @spaces.GPU(duration=80)
+ def core_fn(image_path: str, video_params, working_dir):
+ image_raw = os.path.join(working_dir.name, "raw.png")
+ with Image.open(image_path).convert('RGB') as img:
+ img.save(image_raw)
+
+ base_vid = os.path.basename(video_params).split(".")[0]
+ flame_params_dir = os.path.join("./assets/sample_motion/export", base_vid, "flame_param")
+ base_iid = os.path.basename(image_path).split('.')[0]
+ image_path = os.path.join("./assets/sample_input", base_iid, "images/00000_00.png")
+
+ dump_video_path = os.path.join(working_dir.name, "output.mp4")
+ dump_image_path = os.path.join(working_dir.name, "output.png")
+
+ # prepare dump paths
+ omit_prefix = os.path.dirname(image_raw)
+ image_name = os.path.basename(image_raw)
+ uid = image_name.split(".")[0]
+ subdir_path = os.path.dirname(image_raw).replace(omit_prefix, "")
+ subdir_path = (
+ subdir_path[1:] if subdir_path.startswith("/") else subdir_path
+ )
+ print("subdir_path and uid:", subdir_path, uid)
+
+ motion_seqs_dir = flame_params_dir
+
+ dump_image_dir = os.path.dirname(dump_image_path)
+ os.makedirs(dump_image_dir, exist_ok=True)
+
+ print(image_raw, motion_seqs_dir, dump_image_dir, dump_video_path)
+
+ dump_tmp_dir = dump_image_dir
+
+ if os.path.exists(dump_video_path):
+ return dump_image_path, dump_video_path
+
+ motion_img_need_mask = cfg.get("motion_img_need_mask", False) # False
+ vis_motion = cfg.get("vis_motion", False) # False
+
+ # preprocess input image: segmentation, flame params estimation
+ # """
+ return_code = flametracking.preprocess(image_raw)
+ assert (return_code == 0), "flametracking preprocess failed!"
+ return_code = flametracking.optimize()
+ assert (return_code == 0), "flametracking optimize failed!"
+ return_code, output_dir = flametracking.export()
+ assert (return_code == 0), "flametracking export failed!"
+ image_path = os.path.join(output_dir, "images/00000_00.png")
+ # """
+
+ mask_path = image_path.replace("/images/", "/fg_masks/").replace(".jpg", ".png")
+ print(image_path, mask_path)
+
+ aspect_standard = 1.0 / 1.0
+ source_size = cfg.source_size
+ render_size = cfg.render_size
+ render_fps = 30
+ # prepare reference image
+ image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0,
+ bg_color=1.,
+ max_tgt_size=None, aspect_standard=aspect_standard,
+ enlarge_ratio=[1.0, 1.0],
+ render_tgt_size=source_size, multiply=14, need_mask=True,
+ get_shape_param=True)
+
+ # save masked image for vis
+ save_ref_img_path = os.path.join(dump_tmp_dir, "output.png")
+ vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() * 255).astype(np.uint8)
+ Image.fromarray(vis_ref_img).save(save_ref_img_path)
+
+ # prepare motion seq
+ src = image_path.split('/')[-3]
+ driven = motion_seqs_dir.split('/')[-2]
+ src_driven = [src, driven]
+ motion_seq = prepare_motion_seqs(motion_seqs_dir, None, save_root=dump_tmp_dir, fps=render_fps,
+ bg_color=1., aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1, 0],
+ render_image_res=render_size, multiply=16,
+ need_mask=motion_img_need_mask, vis_motion=vis_motion,
+ shape_param=shape_param, test_sample=False, cross_id=False,
+ src_driven=src_driven, max_squen_length=300)
+
+ # start inference
+ motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)
+ device, dtype = "cuda", torch.float32
+ print("start to inference...................")
+ with torch.no_grad():
+ # TODO check device and dtype
+ res = lam.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None,
+ render_c2ws=motion_seq["render_c2ws"].to(device),
+ render_intrs=motion_seq["render_intrs"].to(device),
+ render_bg_colors=motion_seq["render_bg_colors"].to(device),
+ flame_params={k: v.to(device) for k, v in motion_seq["flame_params"].items()})
+
+ rgb = res["comp_rgb"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
+ mask = res["comp_mask"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
+ mask[mask < 0.5] = 0.0
+ rgb = rgb * mask + (1 - mask) * 1
+ rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8)
+ if vis_motion:
+ vis_ref_img = np.tile(
+ cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :,
+ :],
+ (rgb.shape[0], 1, 1, 1),
+ )
+ rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2)
+
+ os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
+
+ print("==="*36, "\nrgb length:", rgb.shape, render_fps, "==="*36)
+ save_imgs_2_video(rgb, dump_video_path, render_fps)
+ # images_to_video(rgb, output_path=dump_video_path, fps=30, gradio_codec=False, verbose=True)
+ audio_path = os.path.join("./assets/sample_motion/export", base_vid, base_vid + ".wav")
+ dump_video_path_wa = dump_video_path.replace(".mp4", "_audio.mp4")
+ add_audio_to_video(dump_video_path, dump_video_path_wa, audio_path)
+
+ return dump_image_path, dump_video_path_wa
+
+ def core_fn_space(image_path: str, video_params, working_dir):
+ return core_fn(image_path, video_params, working_dir)
+
+ with gr.Blocks(analytics_enabled=False) as demo:
+
+ logo_url = './assets/images/logo.jpeg'
+ logo_base64 = get_image_base64(logo_url)
+ gr.HTML(f"""
+
+
+
Large Avatar Model for One-shot Animatable Gaussian Head
+
+
+ """)
+
+ gr.HTML(
+ """
+
+ """
+ )
+
+
+ gr.HTML("""
+
Notes1: Inputing front-face images or face orientation close to the driven signal gets better results.
+
Notes2: Due to computational constraints with Hugging Face's ZeroGPU infrastructure, video generation requires ~1 minute per instance.
+
Notes3: Using LAM-20K model (lower quality than premium LAM-80K) to mitigate processing latency.
+
""")
+
+
+
+
+ # DISPLAY
+ with gr.Row():
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='lam_input_image'):
+ with gr.TabItem('Input Image'):
+ with gr.Row():
+ input_image = gr.Image(label='Input Image',
+ image_mode='RGB',
+ height=480,
+ width=270,
+ sources='upload',
+ type='filepath',
+ elem_id='content_image')
+ # EXAMPLES
+ with gr.Row():
+ examples = [
+ ['assets/sample_input/messi.png'],
+ ['assets/sample_input/status.png'],
+ ['assets/sample_input/james.png'],
+ ['assets/sample_input/cluo.jpg'],
+ ['assets/sample_input/dufu.jpg'],
+ ['assets/sample_input/libai.jpg'],
+ ['assets/sample_input/barbara.jpg'],
+ ['assets/sample_input/pop.png'],
+ ['assets/sample_input/musk.jpg'],
+ ['assets/sample_input/speed.jpg'],
+ ['assets/sample_input/zhouxingchi.jpg'],
+ ]
+ gr.Examples(
+ examples=examples,
+ inputs=[input_image],
+ examples_per_page=20
+ )
+
+
+ with gr.Column():
+ with gr.Tabs(elem_id='lam_input_video'):
+ with gr.TabItem('Input Video'):
+ with gr.Row():
+ video_input = gr.Video(label='Input Video',
+ height=480,
+ width=270,
+ interactive=False)
+
+ examples = ['./assets/sample_motion/export/Speeding_Scandal/Speeding_Scandal.mp4',
+ './assets/sample_motion/export/Look_In_My_Eyes/Look_In_My_Eyes.mp4',
+ './assets/sample_motion/export/D_ANgelo_Dinero/D_ANgelo_Dinero.mp4',
+ './assets/sample_motion/export/Michael_Wayne_Rosen/Michael_Wayne_Rosen.mp4',
+ './assets/sample_motion/export/I_Am_Iron_Man/I_Am_Iron_Man.mp4',
+ './assets/sample_motion/export/Anti_Drugs/Anti_Drugs.mp4',
+ './assets/sample_motion/export/Pen_Pineapple_Apple_Pen/Pen_Pineapple_Apple_Pen.mp4',
+ './assets/sample_motion/export/Joe_Biden/Joe_Biden.mp4',
+ './assets/sample_motion/export/Donald_Trump/Donald_Trump.mp4',
+ './assets/sample_motion/export/Taylor_Swift/Taylor_Swift.mp4',
+ './assets/sample_motion/export/GEM/GEM.mp4',
+ './assets/sample_motion/export/The_Shawshank_Redemption/The_Shawshank_Redemption.mp4'
+ ]
+ print("Video example list {}".format(examples))
+
+ gr.Examples(
+ examples=examples,
+ inputs=[video_input],
+ examples_per_page=20,
+ )
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='lam_processed_image'):
+ with gr.TabItem('Processed Image'):
+ with gr.Row():
+ processed_image = gr.Image(
+ label='Processed Image',
+ image_mode='RGBA',
+ type='filepath',
+ elem_id='processed_image',
+ height=480,
+ width=270,
+ interactive=False)
+
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='lam_render_video'):
+ with gr.TabItem('Rendered Video'):
+ with gr.Row():
+ output_video = gr.Video(label='Rendered Video',
+ format='mp4',
+ height=480,
+ width=270,
+ autoplay=True)
+
+ # SETTING
+ with gr.Row():
+ with gr.Column(variant='panel', scale=1):
+ submit = gr.Button('Generate',
+ elem_id='lam_generate',
+ variant='primary')
+
+ main_fn = core_fn
+
+ working_dir = gr.State()
+ submit.click(
+ fn=assert_input_image,
+ inputs=[input_image],
+ queue=False,
+ ).success(
+ fn=prepare_working_dir,
+ outputs=[working_dir],
+ queue=False,
+ ).success(
+ fn=main_fn,
+ inputs=[input_image, video_input,
+ working_dir], # video_params refer to smpl dir
+ outputs=[processed_image, output_video],
+ )
+
+ demo.queue()
+ demo.launch()
+
+
+def _build_model(cfg):
+ from lam.models import model_dict
+ from lam.utils.hf_hub import wrap_model_hub
+
+ hf_model_cls = wrap_model_hub(model_dict["lam"])
+ model = hf_model_cls.from_pretrained(cfg.model_name)
+
+ return model
+
+
+def launch_gradio_app():
+ os.environ.update({
+ 'APP_ENABLED': '1',
+ 'APP_MODEL_NAME':
+ './exps/releases/lam/lam-20k/step_045500/',
+ 'APP_INFER': './configs/inference/lam-20k-8gpu.yaml',
+ 'APP_TYPE': 'infer.lam',
+ 'NUMBA_THREADING_LAYER': 'omp',
+ })
+
+ cfg, _ = parse_configs()
+ lam = _build_model(cfg)
+ lam.to('cuda')
+
+ flametracking = FlameTrackingSingleImage(output_dir='tracking_output',
+ alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
+ vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
+ human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
+ facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
+ detect_iris_landmarks=False)
+
+ demo_lam(flametracking, lam, cfg)
+
+
+if __name__ == '__main__':
+ launch_pretrained()
+ launch_gradio_app()
diff --git a/app_lam.py b/app_lam.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cdf73ee74b7fea2001f1b087d90f3d7566e46ca
--- /dev/null
+++ b/app_lam.py
@@ -0,0 +1,433 @@
+# Copyright (c) 2024-2025, Yisheng He, Yuan Dong
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import base64
+import subprocess
+
+import gradio as gr
+import numpy as np
+from PIL import Image
+import argparse
+from omegaconf import OmegaConf
+
+import torch
+from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image
+import moviepy.editor as mpy
+from lam.utils.ffmpeg_utils import images_to_video
+import sys
+from flame_tracking_single_image import FlameTrackingSingleImage
+
+try:
+ import spaces
+except:
+ pass
+
+
+def launch_pretrained():
+ from huggingface_hub import snapshot_download, hf_hub_download
+ hf_hub_download(repo_id='DyrusQZ/LHM_Runtime',
+ repo_type='model',
+ filename='assets.tar',
+ local_dir='./')
+ os.system('tar -xvf assets.tar && rm assets.tar')
+ hf_hub_download(repo_id='DyrusQZ/LHM_Runtime',
+ repo_type='model',
+ filename='LHM-0.5B.tar',
+ local_dir='./')
+ os.system('tar -xvf LHM-0.5B.tar && rm LHM-0.5B.tar')
+ hf_hub_download(repo_id='DyrusQZ/LHM_Runtime',
+ repo_type='model',
+ filename='LHM_prior_model.tar',
+ local_dir='./')
+ os.system('tar -xvf LHM_prior_model.tar && rm LHM_prior_model.tar')
+
+
+def launch_env_not_compile_with_cuda():
+ os.system('pip install chumpy')
+ os.system('pip uninstall -y basicsr')
+ os.system('pip install git+https://github.com/hitsz-zuoqi/BasicSR/')
+ os.system('pip install numpy==1.23.0')
+ os.system(
+ 'pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html'
+ )
+
+
+def assert_input_image(input_image):
+ if input_image is None:
+ raise gr.Error('No image selected or uploaded!')
+
+
+def prepare_working_dir():
+ import tempfile
+ working_dir = tempfile.TemporaryDirectory()
+ return working_dir
+
+
+def init_preprocessor():
+ from lam.utils.preprocess import Preprocessor
+ global preprocessor
+ preprocessor = Preprocessor()
+
+
+def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool,
+ working_dir):
+ image_raw = os.path.join(working_dir.name, 'raw.png')
+ with Image.fromarray(image_in) as img:
+ img.save(image_raw)
+ image_out = os.path.join(working_dir.name, 'rembg.png')
+ success = preprocessor.preprocess(image_path=image_raw,
+ save_path=image_out,
+ rmbg=remove_bg,
+ recenter=recenter)
+ assert success, f'Failed under preprocess_fn!'
+ return image_out
+
+
+def get_image_base64(path):
+ with open(path, 'rb') as image_file:
+ encoded_string = base64.b64encode(image_file.read()).decode()
+ return f'data:image/png;base64,{encoded_string}'
+
+
+def save_imgs_2_video(imgs, v_pth, fps):
+ img_lst = [imgs[i] for i in range(imgs.shape[0])]
+ # Convert the list of NumPy arrays to a list of ImageClip objects
+ clips = [mpy.ImageClip(img).set_duration(0.1) for img in img_lst] # 0.1 seconds per frame
+
+ # Concatenate the ImageClips into a single VideoClip
+ video = mpy.concatenate_videoclips(clips, method="compose")
+
+ # Write the VideoClip to a file
+ video.write_videofile(v_pth, fps=fps) # setting fps to 10 as example
+
+
+def parse_configs():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str)
+ parser.add_argument("--infer", type=str)
+ args, unknown = parser.parse_known_args()
+
+ cfg = OmegaConf.create()
+ cli_cfg = OmegaConf.from_cli(unknown)
+
+ # parse from ENV
+ if os.environ.get("APP_INFER") is not None:
+ args.infer = os.environ.get("APP_INFER")
+ if os.environ.get("APP_MODEL_NAME") is not None:
+ cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")
+
+ args.config = args.infer if args.config is None else args.config
+
+ if args.config is not None:
+ cfg_train = OmegaConf.load(args.config)
+ cfg.source_size = cfg_train.dataset.source_image_res
+ try:
+ cfg.src_head_size = cfg_train.dataset.src_head_size
+ except:
+ cfg.src_head_size = 112
+ cfg.render_size = cfg_train.dataset.render_image.high
+ _relative_path = os.path.join(
+ cfg_train.experiment.parent,
+ cfg_train.experiment.child,
+ os.path.basename(cli_cfg.model_name).split("_")[-1],
+ )
+
+ cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
+ cfg.image_dump = os.path.join("exps", "images", _relative_path)
+ cfg.video_dump = os.path.join("exps", "videos", _relative_path) # output path
+
+ if args.infer is not None:
+ cfg_infer = OmegaConf.load(args.infer)
+ cfg.merge_with(cfg_infer)
+ cfg.setdefault(
+ "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
+ )
+ cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
+ cfg.setdefault(
+ "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
+ )
+ cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))
+
+ cfg.motion_video_read_fps = 6
+ cfg.merge_with(cli_cfg)
+
+ cfg.setdefault("logger", "INFO")
+
+ assert cfg.model_name is not None, "model_name is required"
+
+ return cfg, cfg_train
+
+
+def demo_lam(flametracking, lam, cfg):
+
+ # @spaces.GPU(duration=80)
+ def core_fn(image_path: str, video_params, working_dir):
+ image_raw = os.path.join(working_dir.name, "raw.png")
+ with Image.open(image_path).convert('RGB') as img:
+ img.save(image_raw)
+
+ base_vid = os.path.basename(video_params).split(".")[0]
+ flame_params_dir = os.path.join("./assets/sample_motion/export", base_vid, "flame_param")
+ base_iid = os.path.basename(image_path).split('.')[0]
+ image_path = os.path.join("./assets/sample_input", base_iid, "images/00000_00.png")
+
+ dump_video_path = os.path.join(working_dir.name, "output.mp4")
+ dump_image_path = os.path.join(working_dir.name, "output.png")
+
+ # prepare dump paths
+ omit_prefix = os.path.dirname(image_raw)
+ image_name = os.path.basename(image_raw)
+ uid = image_name.split(".")[0]
+ subdir_path = os.path.dirname(image_raw).replace(omit_prefix, "")
+ subdir_path = (
+ subdir_path[1:] if subdir_path.startswith("/") else subdir_path
+ )
+ print("subdir_path and uid:", subdir_path, uid)
+
+ motion_seqs_dir = flame_params_dir
+
+ dump_image_dir = os.path.dirname(dump_image_path)
+ os.makedirs(dump_image_dir, exist_ok=True)
+
+ print(image_raw, motion_seqs_dir, dump_image_dir, dump_video_path)
+
+ dump_tmp_dir = dump_image_dir
+
+ if os.path.exists(dump_video_path):
+ return dump_image_path, dump_video_path
+
+ motion_img_need_mask = cfg.get("motion_img_need_mask", False) # False
+ vis_motion = cfg.get("vis_motion", False) # False
+
+ # preprocess input image: segmentation, flame params estimation
+ return_code = flametracking.preprocess(image_raw)
+ assert (return_code == 0), "flametracking preprocess failed!"
+ return_code = flametracking.optimize()
+ assert (return_code == 0), "flametracking optimize failed!"
+ return_code, output_dir = flametracking.export()
+ assert (return_code == 0), "flametracking export failed!"
+
+ image_path = os.path.join(output_dir, "images/00000_00.png")
+ mask_path = image_path.replace("/images/", "/fg_masks/").replace(".jpg", ".png")
+ print(image_path, mask_path)
+
+ aspect_standard = 1.0/1.0
+ source_size = cfg.source_size
+ render_size = cfg.render_size
+ render_fps = 30
+ # prepare reference image
+ image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0, bg_color=1.,
+ max_tgt_size=None, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
+ render_tgt_size=source_size, multiply=14, need_mask=True, get_shape_param=True)
+
+ # save masked image for vis
+ save_ref_img_path = os.path.join(dump_tmp_dir, "output.png")
+ vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() * 255).astype(np.uint8)
+ Image.fromarray(vis_ref_img).save(save_ref_img_path)
+
+ # prepare motion seq
+ src = image_path.split('/')[-3]
+ driven = motion_seqs_dir.split('/')[-2]
+ src_driven = [src, driven]
+ motion_seq = prepare_motion_seqs(motion_seqs_dir, None, save_root=dump_tmp_dir, fps=render_fps,
+ bg_color=1., aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1,0],
+ render_image_res=render_size, multiply=16,
+ need_mask=motion_img_need_mask, vis_motion=vis_motion,
+ shape_param=shape_param, test_sample=False, cross_id=False, src_driven=src_driven)
+
+ # start inference
+ motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)
+ device, dtype = "cuda", torch.float32
+ print("start to inference...................")
+ with torch.no_grad():
+ # TODO check device and dtype
+ res = lam.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None,
+ render_c2ws=motion_seq["render_c2ws"].to(device),
+ render_intrs=motion_seq["render_intrs"].to(device),
+ render_bg_colors=motion_seq["render_bg_colors"].to(device),
+ flame_params={k:v.to(device) for k, v in motion_seq["flame_params"].items()})
+
+ rgb = res["comp_rgb"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
+ mask = res["comp_mask"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
+ mask[mask < 0.5] = 0.0
+ rgb = rgb * mask + (1 - mask) * 1
+ rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8)
+ if vis_motion:
+ vis_ref_img = np.tile(
+ cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :, :],
+ (rgb.shape[0], 1, 1, 1),
+ )
+ rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2)
+
+ os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
+
+ save_imgs_2_video(rgb, dump_video_path, render_fps)
+ # images_to_video(rgb, output_path=dump_video_path, fps=30, gradio_codec=False, verbose=True)
+
+ return dump_image_path, dump_video_path
+
+ with gr.Blocks(analytics_enabled=False) as demo:
+
+ logo_url = './assets/images/logo.png'
+ logo_base64 = get_image_base64(logo_url)
+ gr.HTML(f"""
+
+
+
LAM: Large Avatar Model for One-shot Animatable Gaussian Head
+
+
+ """)
+ gr.HTML(
+ """
Notes: Inputing front-face images or face orientation close to the driven signal gets better results. """
+ )
+
+ # DISPLAY
+ with gr.Row():
+
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='lam_input_image'):
+ with gr.TabItem('Input Image'):
+ with gr.Row():
+ input_image = gr.Image(label='Input Image',
+ image_mode='RGB',
+ height=480,
+ width=270,
+ sources='upload',
+ type='filepath', # 'numpy',
+ elem_id='content_image')
+ # EXAMPLES
+ with gr.Row():
+ examples = [
+ ['assets/sample_input/2w01/images/2w01.png'],
+ ['assets/sample_input/2w02/images/2w02.png'],
+ ['assets/sample_input/2w03/images/2w03.png'],
+ ['assets/sample_input/2w04/images/2w04.png'],
+ ]
+ gr.Examples(
+ examples=examples,
+ inputs=[input_image],
+ examples_per_page=20,
+ )
+
+ with gr.Column():
+ with gr.Tabs(elem_id='lam_input_video'):
+ with gr.TabItem('Input Video'):
+ with gr.Row():
+ video_input = gr.Video(label='Input Video',
+ height=480,
+ width=270,
+ interactive=False)
+
+ examples = [
+ './assets/sample_motion/export/clip1/clip1.mp4',
+ './assets/sample_motion/export/clip2/clip2.mp4',
+ './assets/sample_motion/export/clip3/clip3.mp4',
+ ]
+
+ gr.Examples(
+ examples=examples,
+ inputs=[video_input],
+ examples_per_page=20,
+ )
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='lam_processed_image'):
+ with gr.TabItem('Processed Image'):
+ with gr.Row():
+ processed_image = gr.Image(
+ label='Processed Image',
+ image_mode='RGBA',
+ type='filepath',
+ elem_id='processed_image',
+ height=480,
+ width=270,
+ interactive=False)
+
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='lam_render_video'):
+ with gr.TabItem('Rendered Video'):
+ with gr.Row():
+ output_video = gr.Video(label='Rendered Video',
+ format='mp4',
+ height=480,
+ width=270,
+ autoplay=True)
+
+ # SETTING
+ with gr.Row():
+ with gr.Column(variant='panel', scale=1):
+ submit = gr.Button('Generate',
+ elem_id='lam_generate',
+ variant='primary')
+
+ working_dir = gr.State()
+ submit.click(
+ fn=assert_input_image,
+ inputs=[input_image],
+ queue=False,
+ ).success(
+ fn=prepare_working_dir,
+ outputs=[working_dir],
+ queue=False,
+ ).success(
+ fn=core_fn,
+ inputs=[input_image, video_input,
+ working_dir], # video_params refer to smpl dir
+ outputs=[processed_image, output_video],
+ )
+
+ demo.queue()
+ demo.launch()
+
+
+def _build_model(cfg):
+ from lam.models import model_dict
+ from lam.utils.hf_hub import wrap_model_hub
+
+ hf_model_cls = wrap_model_hub(model_dict["lam"])
+ model = hf_model_cls.from_pretrained(cfg.model_name)
+
+ return model
+
+def launch_gradio_app():
+
+ os.environ.update({
+ 'APP_ENABLED': '1',
+ 'APP_MODEL_NAME':
+ './exps/releases/lam/lam-20k/step_045500/',
+ 'APP_INFER': './configs/inference/lam-20k-8gpu.yaml',
+ 'APP_TYPE': 'infer.lam',
+ 'NUMBA_THREADING_LAYER': 'omp',
+ })
+
+ cfg, _ = parse_configs()
+ lam = _build_model(cfg)
+ lam.to('cuda')
+
+ flametracking = FlameTrackingSingleImage(output_dir='tracking_output',
+ alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
+ vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
+ human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
+ facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
+ detect_iris_landmarks=True)
+
+ demo_lam(flametracking, lam, cfg)
+
+
+if __name__ == '__main__':
+ # launch_pretrained()
+ # launch_env_not_compile_with_cuda()
+ launch_gradio_app()
diff --git a/app_preprocess.py b/app_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..511c68a5746d862327243c1d6fa33c61702c7b15
--- /dev/null
+++ b/app_preprocess.py
@@ -0,0 +1,387 @@
+# Copyright (c) 2023-2024, Qi Zuo
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+os.system('rm -rf /data-nvme/zerogpu-offload/')
+os.system('pip install numpy==1.23.0')
+os.system('pip install ./wheels/pytorch3d-0.7.3-cp310-cp310-linux_x86_64.whl')
+
+import argparse
+import base64
+import time
+
+import cv2
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+from PIL import Image
+
+import gradio as gr
+import spaces
+from flame_tracking_single_image import FlameTrackingSingleImage
+from ffmpeg_utils import images_to_video
+
+# torch._dynamo.config.disable = True
+
+
+def parse_configs():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str)
+ parser.add_argument('--infer', type=str)
+ args, unknown = parser.parse_known_args()
+
+ cfg = OmegaConf.create()
+ cli_cfg = OmegaConf.from_cli(unknown)
+
+ # parse from ENV
+ if os.environ.get('APP_INFER') is not None:
+ args.infer = os.environ.get('APP_INFER')
+ if os.environ.get('APP_MODEL_NAME') is not None:
+ cli_cfg.model_name = os.environ.get('APP_MODEL_NAME')
+
+ args.config = args.infer if args.config is None else args.config
+
+ if args.config is not None:
+ cfg_train = OmegaConf.load(args.config)
+ cfg.source_size = cfg_train.dataset.source_image_res
+ try:
+ cfg.src_head_size = cfg_train.dataset.src_head_size
+ except:
+ cfg.src_head_size = 112
+ cfg.render_size = cfg_train.dataset.render_image.high
+ _relative_path = os.path.join(
+ cfg_train.experiment.parent,
+ cfg_train.experiment.child,
+ os.path.basename(cli_cfg.model_name).split('_')[-1],
+ )
+
+ cfg.save_tmp_dump = os.path.join('exps', 'save_tmp', _relative_path)
+ cfg.image_dump = os.path.join('exps', 'images', _relative_path)
+ cfg.video_dump = os.path.join('exps', 'videos',
+ _relative_path) # output path
+
+ if args.infer is not None:
+ cfg_infer = OmegaConf.load(args.infer)
+ cfg.merge_with(cfg_infer)
+ cfg.setdefault('save_tmp_dump',
+ os.path.join('exps', cli_cfg.model_name, 'save_tmp'))
+ cfg.setdefault('image_dump',
+ os.path.join('exps', cli_cfg.model_name, 'images'))
+ cfg.setdefault('video_dump',
+ os.path.join('dumps', cli_cfg.model_name, 'videos'))
+ cfg.setdefault('mesh_dump',
+ os.path.join('dumps', cli_cfg.model_name, 'meshes'))
+
+ cfg.motion_video_read_fps = 6
+ cfg.merge_with(cli_cfg)
+
+ cfg.setdefault('logger', 'INFO')
+
+ assert cfg.model_name is not None, 'model_name is required'
+
+ return cfg, cfg_train
+
+
+
+def launch_pretrained():
+ from huggingface_hub import snapshot_download, hf_hub_download
+ hf_hub_download(repo_id='yuandong513/flametracking_model',
+ repo_type='model',
+ filename='pretrain_model.tar',
+ local_dir='./')
+ os.system('tar -xf pretrain_model.tar && rm pretrain_model.tar')
+
+def animation_infer(renderer, gs_model_list, query_points, smplx_params,
+ render_c2ws, render_intrs, render_bg_colors):
+ '''Inference code avoid repeat forward.
+ '''
+ render_h, render_w = int(render_intrs[0, 0, 1, 2] * 2), int(
+ render_intrs[0, 0, 0, 2] * 2)
+ # render target views
+ render_res_list = []
+ num_views = render_c2ws.shape[1]
+ start_time = time.time()
+
+ # render target views
+ render_res_list = []
+
+ for view_idx in range(num_views):
+ render_res = renderer.forward_animate_gs(
+ gs_model_list,
+ query_points,
+ renderer.get_single_view_smpl_data(smplx_params, view_idx),
+ render_c2ws[:, view_idx:view_idx + 1],
+ render_intrs[:, view_idx:view_idx + 1],
+ render_h,
+ render_w,
+ render_bg_colors[:, view_idx:view_idx + 1],
+ )
+ render_res_list.append(render_res)
+ print(
+ f'time elpased(animate gs model per frame):{(time.time() - start_time)/num_views}'
+ )
+
+ out = defaultdict(list)
+ for res in render_res_list:
+ for k, v in res.items():
+ if isinstance(v[0], torch.Tensor):
+ out[k].append(v.detach().cpu())
+ else:
+ out[k].append(v)
+ for k, v in out.items():
+ # print(f"out key:{k}")
+ if isinstance(v[0], torch.Tensor):
+ out[k] = torch.concat(v, dim=1)
+ if k in ['comp_rgb', 'comp_mask', 'comp_depth']:
+ out[k] = out[k][0].permute(
+ 0, 2, 3,
+ 1) # [1, Nv, 3, H, W] -> [Nv, 3, H, W] - > [Nv, H, W, 3]
+ else:
+ out[k] = v
+ return out
+
+
+def assert_input_image(input_image):
+ if input_image is None:
+ raise gr.Error('No image selected or uploaded!')
+
+
+def prepare_working_dir():
+ import tempfile
+ working_dir = tempfile.TemporaryDirectory()
+ return working_dir
+
+def get_image_base64(path):
+ with open(path, 'rb') as image_file:
+ encoded_string = base64.b64encode(image_file.read()).decode()
+ return f'data:image/png;base64,{encoded_string}'
+
+
+def demo_lhm(flametracking):
+ @spaces.GPU(duration=80)
+ def core_fn(image: str, video_params, working_dir):
+ image_raw = os.path.join(working_dir.name, 'raw.png')
+ with Image.fromarray(image) as img:
+ img.save(image_raw)
+
+ base_vid = os.path.basename(video_params).split('_')[0]
+
+ dump_video_path = os.path.join(working_dir.name, 'output.mp4')
+ dump_image_path = os.path.join(working_dir.name, 'output.png')
+
+ # prepare dump paths
+ omit_prefix = os.path.dirname(image_raw)
+ image_name = os.path.basename(image_raw)
+ uid = image_name.split('.')[0]
+ subdir_path = os.path.dirname(image_raw).replace(omit_prefix, '')
+ subdir_path = (subdir_path[1:]
+ if subdir_path.startswith('/') else subdir_path)
+ print('==> subdir_path and uid:', subdir_path, uid)
+
+ dump_image_dir = os.path.dirname(dump_image_path)
+ os.makedirs(dump_image_dir, exist_ok=True)
+
+ print('==> path:', image_raw, dump_image_dir, dump_video_path)
+
+ dump_tmp_dir = dump_image_dir
+
+ return_code = flametracking.preprocess(image_raw)
+ return_code = flametracking.optimize()
+ return_code, output_dir = flametracking.export()
+
+ print("==> output_dir:", output_dir)
+
+
+ save_ref_img_path = os.path.join(dump_tmp_dir, 'output.png')
+ vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() *
+ 255).astype(np.uint8)
+ Image.fromarray(vis_ref_img).save(save_ref_img_path)
+
+ # rendering !!!!
+ start_time = time.time()
+ batch_dict = dict()
+
+ rgb = cv2.imread(os.path.join(output_dir,'images/00000_00.png'))
+
+ for i in range(30):
+ images_to_video(
+ rgb,
+ output_path=dump_video_path,
+ fps=30,
+ gradio_codec=False,
+ verbose=True,
+ )
+
+ return dump_image_path, dump_video_path
+
+ _TITLE = '''LHM: Large Animatable Human Model'''
+
+ _DESCRIPTION = '''
+ Reconstruct a human avatar in 0.2 seconds with A100!
+ '''
+
+ with gr.Blocks(analytics_enabled=False, delete_cache=[3600, 3600]) as demo:
+
+ #
+ logo_url = './asset/logo.jpeg'
+ logo_base64 = get_image_base64(logo_url)
+ gr.HTML(f"""
+
+
+
Large Animatable Human Model
+
+
+ """)
+
+ gr.HTML("""
+
+ """)
+
+ gr.HTML(
+ """
Notes: Please input full-body image in case of detection errors. We simplify the pipeline in spaces: 1) using Rembg instead of SAM2; 2) limit the output video length to 10s; For best visual quality, try the inference code on Github instead. """
+ )
+
+ # DISPLAY
+ with gr.Row():
+
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='openlrm_input_image'):
+ with gr.TabItem('Input Image'):
+ with gr.Row():
+ input_image = gr.Image(label='Input Image',
+ image_mode='RGB',
+ height=480,
+ width=270,
+ sources='upload',
+ type='numpy',
+ elem_id='content_image')
+ # EXAMPLES
+ with gr.Row():
+ examples = [
+ ['asset/sample_input/00000.png'],
+ ]
+ gr.Examples(
+ examples=examples,
+ inputs=[input_image],
+ examples_per_page=10,
+ )
+
+ with gr.Column():
+ with gr.Tabs(elem_id='openlrm_input_video'):
+ with gr.TabItem('Input Video'):
+ with gr.Row():
+ video_input = gr.Video(label='Input Video',
+ height=480,
+ width=270,
+ interactive=False)
+
+ examples = [
+ './asset/sample_input/demo.mp4',
+ ]
+
+ gr.Examples(
+ examples=examples,
+ inputs=[video_input],
+ examples_per_page=20,
+ )
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='openlrm_processed_image'):
+ with gr.TabItem('Processed Image'):
+ with gr.Row():
+ processed_image = gr.Image(
+ label='Processed Image',
+ image_mode='RGB',
+ type='filepath',
+ elem_id='processed_image',
+ height=480,
+ width=270,
+ interactive=False)
+
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='openlrm_render_video'):
+ with gr.TabItem('Rendered Video'):
+ with gr.Row():
+ output_video = gr.Video(label='Rendered Video',
+ format='mp4',
+ height=480,
+ width=270,
+ autoplay=True)
+
+ # SETTING
+ with gr.Row():
+ with gr.Column(variant='panel', scale=1):
+ submit = gr.Button('Generate',
+ elem_id='openlrm_generate',
+ variant='primary')
+
+ working_dir = gr.State()
+ submit.click(
+ fn=assert_input_image,
+ inputs=[input_image],
+ queue=False,
+ ).success(
+ fn=prepare_working_dir,
+ outputs=[working_dir],
+ queue=False,
+ ).success(
+ fn=core_fn,
+ inputs=[input_image, video_input,
+ working_dir], # video_params refer to smpl dir
+ outputs=[processed_image, output_video],
+ )
+
+ demo.queue(max_size=1)
+ demo.launch()
+
+
+def launch_gradio_app():
+
+ os.environ.update({
+ 'APP_ENABLED': '1',
+ 'APP_MODEL_NAME':
+ './exps/releases/video_human_benchmark/human-lrm-500M/step_060000/',
+ 'APP_INFER': './configs/inference/human-lrm-500M.yaml',
+ 'APP_TYPE': 'infer.human_lrm',
+ 'NUMBA_THREADING_LAYER': 'omp',
+ })
+
+ flametracking = FlameTrackingSingleImage(output_dir='tracking_output',
+ alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
+ vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
+ human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
+ facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
+ detect_iris_landmarks=True)
+
+
+ demo_lhm(flametracking)
+
+
+if __name__ == '__main__':
+ launch_pretrained()
+ launch_gradio_app()
+
diff --git a/configs/inference/lam-20k-8gpu.yaml b/configs/inference/lam-20k-8gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f7d471540e5bbb0e733e3d91f402ff3fb81a1598
--- /dev/null
+++ b/configs/inference/lam-20k-8gpu.yaml
@@ -0,0 +1,130 @@
+
+experiment:
+ type: lam
+ seed: 42
+ parent: lam
+ child: lam_20k
+model:
+ # image encoder
+ encoder_type: "dinov2_fusion"
+ encoder_model_name: "dinov2_vitl14_reg"
+ encoder_feat_dim: 1024
+ encoder_freeze: false
+
+ # points embeddings
+ latent_query_points_type: "e2e_flame"
+ pcl_dim: 1024
+
+ # transformer
+ transformer_type: "sd3_cond"
+ transformer_heads: 16
+ transformer_dim: 1024
+ transformer_layers: 10
+ tf_grad_ckpt: true
+ encoder_grad_ckpt: true
+
+ # for gs renderer
+ human_model_path: "./pretrained_models/human_model_files"
+ flame_subdivide_num: 1
+ flame_type: "flame"
+ gs_query_dim: 1024
+ gs_use_rgb: True
+ gs_sh: 3
+ gs_mlp_network_config:
+ n_neurons: 512
+ n_hidden_layers: 2
+ activation: silu
+ gs_xyz_offset_max_step: 0.2
+ gs_clip_scaling: 0.01
+ scale_sphere: false
+
+ expr_param_dim: 10
+ shape_param_dim: 10
+ add_teeth: false
+
+ fix_opacity: false
+ fix_rotation: false
+
+ has_disc: false
+
+ teeth_bs_flag: false
+ oral_mesh_flag: false
+
+dataset:
+ subsets:
+ - name: video_head
+ root_dirs: "./train_data/vfhq_vhap_nooffset/export"
+ meta_path:
+ train: "./train_data/vfhq_vhap_nooffset/label/valid_id_train_list.json"
+ val: "./train_data/vfhq_vhap_nooffset/label/valid_id_val_list.json"
+ sample_rate: 1.0
+ sample_side_views: 7
+ sample_aug_views: 0
+ source_image_res: 512
+ render_image:
+ low: 512
+ high: 512
+ region: null
+ num_train_workers: 4
+ num_val_workers: 2
+ pin_mem: true
+ repeat_num: 1
+ gaga_track_type: "vfhq"
+
+train:
+ mixed_precision: bf16 # REPLACE THIS BASED ON GPU TYPE
+ find_unused_parameters: false
+ loss:
+ pixel_weight: 0.0
+ pixel_loss_fn: "mse"
+ crop_face_weight: 0.
+ crop_mouth_weight: 0.
+ crop_eye_weight: 0.
+ masked_pixel_weight: 1.0
+ perceptual_weight: 1.0
+ tv_weight: -1
+ mask_weight: 0:1.0:0.5:10000
+ offset_reg_weight: 0.1
+ optim:
+ lr: 4e-4
+ weight_decay: 0.05
+ beta1: 0.9
+ beta2: 0.95
+ clip_grad_norm: 1.0
+ scheduler:
+ type: cosine
+ warmup_real_iters: 3000
+ batch_size: 4 # REPLACE THIS (PER GPU)
+ accum_steps: 1 # REPLACE THIS
+ epochs: 100 # REPLACE THIS
+ debug_global_steps: null
+ resume: ""
+
+val:
+ batch_size: 2
+ global_step_period: 500
+ debug_batches: 10
+
+saver:
+ auto_resume: true
+ load_model: null
+ checkpoint_root: ./exps/checkpoints
+ checkpoint_global_steps: 500
+ checkpoint_keep_level: 5
+
+logger:
+ stream_level: WARNING
+ log_level: INFO
+ log_root: ./exps/logs
+ tracker_root: ./exps/trackers
+ enable_profiler: false
+ trackers:
+ - tensorboard
+ image_monitor:
+ train_global_steps: 500
+ samples_per_log: 4
+
+compile:
+ suppress_errors: true
+ print_specializations: true
+ disable: true
diff --git a/configs/stylematte_config.json b/configs/stylematte_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..3ba17e5cefd8b9be3bc774edf7bda7663b30bd22
--- /dev/null
+++ b/configs/stylematte_config.json
@@ -0,0 +1,2311 @@
+{
+ "_commit_hash": null,
+ "activation_function": "relu",
+ "architectures": [
+ "Mask2FormerForUniversalSegmentation"
+ ],
+ "backbone_config": {
+ "_name_or_path": "",
+ "add_cross_attention": false,
+ "architectures": [
+ "SwinForImageClassification"
+ ],
+ "attention_probs_dropout_prob": 0.0,
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": null,
+ "chunk_size_feed_forward": 0,
+ "cross_attention_hidden_size": null,
+ "decoder_start_token_id": null,
+ "depths": [
+ 2,
+ 2,
+ 6,
+ 2
+ ],
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "drop_path_rate": 0.3,
+ "early_stopping": false,
+ "embed_dim": 96,
+ "encoder_no_repeat_ngram_size": 0,
+ "encoder_stride": 32,
+ "eos_token_id": null,
+ "exponential_decay_length_penalty": null,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "hidden_size": 768,
+ "id2label": {
+ "0": "tench, Tinca tinca",
+ "1": "goldfish, Carassius auratus",
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
+ "3": "tiger shark, Galeocerdo cuvieri",
+ "4": "hammerhead, hammerhead shark",
+ "5": "electric ray, crampfish, numbfish, torpedo",
+ "6": "stingray",
+ "7": "cock",
+ "8": "hen",
+ "9": "ostrich, Struthio camelus",
+ "10": "brambling, Fringilla montifringilla",
+ "11": "goldfinch, Carduelis carduelis",
+ "12": "house finch, linnet, Carpodacus mexicanus",
+ "13": "junco, snowbird",
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
+ "15": "robin, American robin, Turdus migratorius",
+ "16": "bulbul",
+ "17": "jay",
+ "18": "magpie",
+ "19": "chickadee",
+ "20": "water ouzel, dipper",
+ "21": "kite",
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
+ "23": "vulture",
+ "24": "great grey owl, great gray owl, Strix nebulosa",
+ "25": "European fire salamander, Salamandra salamandra",
+ "26": "common newt, Triturus vulgaris",
+ "27": "eft",
+ "28": "spotted salamander, Ambystoma maculatum",
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
+ "30": "bullfrog, Rana catesbeiana",
+ "31": "tree frog, tree-frog",
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
+ "35": "mud turtle",
+ "36": "terrapin",
+ "37": "box turtle, box tortoise",
+ "38": "banded gecko",
+ "39": "common iguana, iguana, Iguana iguana",
+ "40": "American chameleon, anole, Anolis carolinensis",
+ "41": "whiptail, whiptail lizard",
+ "42": "agama",
+ "43": "frilled lizard, Chlamydosaurus kingi",
+ "44": "alligator lizard",
+ "45": "Gila monster, Heloderma suspectum",
+ "46": "green lizard, Lacerta viridis",
+ "47": "African chameleon, Chamaeleo chamaeleon",
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
+ "50": "American alligator, Alligator mississipiensis",
+ "51": "triceratops",
+ "52": "thunder snake, worm snake, Carphophis amoenus",
+ "53": "ringneck snake, ring-necked snake, ring snake",
+ "54": "hognose snake, puff adder, sand viper",
+ "55": "green snake, grass snake",
+ "56": "king snake, kingsnake",
+ "57": "garter snake, grass snake",
+ "58": "water snake",
+ "59": "vine snake",
+ "60": "night snake, Hypsiglena torquata",
+ "61": "boa constrictor, Constrictor constrictor",
+ "62": "rock python, rock snake, Python sebae",
+ "63": "Indian cobra, Naja naja",
+ "64": "green mamba",
+ "65": "sea snake",
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
+ "69": "trilobite",
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
+ "71": "scorpion",
+ "72": "black and gold garden spider, Argiope aurantia",
+ "73": "barn spider, Araneus cavaticus",
+ "74": "garden spider, Aranea diademata",
+ "75": "black widow, Latrodectus mactans",
+ "76": "tarantula",
+ "77": "wolf spider, hunting spider",
+ "78": "tick",
+ "79": "centipede",
+ "80": "black grouse",
+ "81": "ptarmigan",
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
+ "83": "prairie chicken, prairie grouse, prairie fowl",
+ "84": "peacock",
+ "85": "quail",
+ "86": "partridge",
+ "87": "African grey, African gray, Psittacus erithacus",
+ "88": "macaw",
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
+ "90": "lorikeet",
+ "91": "coucal",
+ "92": "bee eater",
+ "93": "hornbill",
+ "94": "hummingbird",
+ "95": "jacamar",
+ "96": "toucan",
+ "97": "drake",
+ "98": "red-breasted merganser, Mergus serrator",
+ "99": "goose",
+ "100": "black swan, Cygnus atratus",
+ "101": "tusker",
+ "102": "echidna, spiny anteater, anteater",
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
+ "104": "wallaby, brush kangaroo",
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
+ "106": "wombat",
+ "107": "jellyfish",
+ "108": "sea anemone, anemone",
+ "109": "brain coral",
+ "110": "flatworm, platyhelminth",
+ "111": "nematode, nematode worm, roundworm",
+ "112": "conch",
+ "113": "snail",
+ "114": "slug",
+ "115": "sea slug, nudibranch",
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
+ "117": "chambered nautilus, pearly nautilus, nautilus",
+ "118": "Dungeness crab, Cancer magister",
+ "119": "rock crab, Cancer irroratus",
+ "120": "fiddler crab",
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
+ "125": "hermit crab",
+ "126": "isopod",
+ "127": "white stork, Ciconia ciconia",
+ "128": "black stork, Ciconia nigra",
+ "129": "spoonbill",
+ "130": "flamingo",
+ "131": "little blue heron, Egretta caerulea",
+ "132": "American egret, great white heron, Egretta albus",
+ "133": "bittern",
+ "134": "crane",
+ "135": "limpkin, Aramus pictus",
+ "136": "European gallinule, Porphyrio porphyrio",
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
+ "138": "bustard",
+ "139": "ruddy turnstone, Arenaria interpres",
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
+ "141": "redshank, Tringa totanus",
+ "142": "dowitcher",
+ "143": "oystercatcher, oyster catcher",
+ "144": "pelican",
+ "145": "king penguin, Aptenodytes patagonica",
+ "146": "albatross, mollymawk",
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
+ "149": "dugong, Dugong dugon",
+ "150": "sea lion",
+ "151": "Chihuahua",
+ "152": "Japanese spaniel",
+ "153": "Maltese dog, Maltese terrier, Maltese",
+ "154": "Pekinese, Pekingese, Peke",
+ "155": "Shih-Tzu",
+ "156": "Blenheim spaniel",
+ "157": "papillon",
+ "158": "toy terrier",
+ "159": "Rhodesian ridgeback",
+ "160": "Afghan hound, Afghan",
+ "161": "basset, basset hound",
+ "162": "beagle",
+ "163": "bloodhound, sleuthhound",
+ "164": "bluetick",
+ "165": "black-and-tan coonhound",
+ "166": "Walker hound, Walker foxhound",
+ "167": "English foxhound",
+ "168": "redbone",
+ "169": "borzoi, Russian wolfhound",
+ "170": "Irish wolfhound",
+ "171": "Italian greyhound",
+ "172": "whippet",
+ "173": "Ibizan hound, Ibizan Podenco",
+ "174": "Norwegian elkhound, elkhound",
+ "175": "otterhound, otter hound",
+ "176": "Saluki, gazelle hound",
+ "177": "Scottish deerhound, deerhound",
+ "178": "Weimaraner",
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
+ "181": "Bedlington terrier",
+ "182": "Border terrier",
+ "183": "Kerry blue terrier",
+ "184": "Irish terrier",
+ "185": "Norfolk terrier",
+ "186": "Norwich terrier",
+ "187": "Yorkshire terrier",
+ "188": "wire-haired fox terrier",
+ "189": "Lakeland terrier",
+ "190": "Sealyham terrier, Sealyham",
+ "191": "Airedale, Airedale terrier",
+ "192": "cairn, cairn terrier",
+ "193": "Australian terrier",
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
+ "195": "Boston bull, Boston terrier",
+ "196": "miniature schnauzer",
+ "197": "giant schnauzer",
+ "198": "standard schnauzer",
+ "199": "Scotch terrier, Scottish terrier, Scottie",
+ "200": "Tibetan terrier, chrysanthemum dog",
+ "201": "silky terrier, Sydney silky",
+ "202": "soft-coated wheaten terrier",
+ "203": "West Highland white terrier",
+ "204": "Lhasa, Lhasa apso",
+ "205": "flat-coated retriever",
+ "206": "curly-coated retriever",
+ "207": "golden retriever",
+ "208": "Labrador retriever",
+ "209": "Chesapeake Bay retriever",
+ "210": "German short-haired pointer",
+ "211": "vizsla, Hungarian pointer",
+ "212": "English setter",
+ "213": "Irish setter, red setter",
+ "214": "Gordon setter",
+ "215": "Brittany spaniel",
+ "216": "clumber, clumber spaniel",
+ "217": "English springer, English springer spaniel",
+ "218": "Welsh springer spaniel",
+ "219": "cocker spaniel, English cocker spaniel, cocker",
+ "220": "Sussex spaniel",
+ "221": "Irish water spaniel",
+ "222": "kuvasz",
+ "223": "schipperke",
+ "224": "groenendael",
+ "225": "malinois",
+ "226": "briard",
+ "227": "kelpie",
+ "228": "komondor",
+ "229": "Old English sheepdog, bobtail",
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
+ "231": "collie",
+ "232": "Border collie",
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
+ "234": "Rottweiler",
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
+ "236": "Doberman, Doberman pinscher",
+ "237": "miniature pinscher",
+ "238": "Greater Swiss Mountain dog",
+ "239": "Bernese mountain dog",
+ "240": "Appenzeller",
+ "241": "EntleBucher",
+ "242": "boxer",
+ "243": "bull mastiff",
+ "244": "Tibetan mastiff",
+ "245": "French bulldog",
+ "246": "Great Dane",
+ "247": "Saint Bernard, St Bernard",
+ "248": "Eskimo dog, husky",
+ "249": "malamute, malemute, Alaskan malamute",
+ "250": "Siberian husky",
+ "251": "dalmatian, coach dog, carriage dog",
+ "252": "affenpinscher, monkey pinscher, monkey dog",
+ "253": "basenji",
+ "254": "pug, pug-dog",
+ "255": "Leonberg",
+ "256": "Newfoundland, Newfoundland dog",
+ "257": "Great Pyrenees",
+ "258": "Samoyed, Samoyede",
+ "259": "Pomeranian",
+ "260": "chow, chow chow",
+ "261": "keeshond",
+ "262": "Brabancon griffon",
+ "263": "Pembroke, Pembroke Welsh corgi",
+ "264": "Cardigan, Cardigan Welsh corgi",
+ "265": "toy poodle",
+ "266": "miniature poodle",
+ "267": "standard poodle",
+ "268": "Mexican hairless",
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
+ "273": "dingo, warrigal, warragal, Canis dingo",
+ "274": "dhole, Cuon alpinus",
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
+ "276": "hyena, hyaena",
+ "277": "red fox, Vulpes vulpes",
+ "278": "kit fox, Vulpes macrotis",
+ "279": "Arctic fox, white fox, Alopex lagopus",
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
+ "281": "tabby, tabby cat",
+ "282": "tiger cat",
+ "283": "Persian cat",
+ "284": "Siamese cat, Siamese",
+ "285": "Egyptian cat",
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
+ "287": "lynx, catamount",
+ "288": "leopard, Panthera pardus",
+ "289": "snow leopard, ounce, Panthera uncia",
+ "290": "jaguar, panther, Panthera onca, Felis onca",
+ "291": "lion, king of beasts, Panthera leo",
+ "292": "tiger, Panthera tigris",
+ "293": "cheetah, chetah, Acinonyx jubatus",
+ "294": "brown bear, bruin, Ursus arctos",
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
+ "298": "mongoose",
+ "299": "meerkat, mierkat",
+ "300": "tiger beetle",
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
+ "302": "ground beetle, carabid beetle",
+ "303": "long-horned beetle, longicorn, longicorn beetle",
+ "304": "leaf beetle, chrysomelid",
+ "305": "dung beetle",
+ "306": "rhinoceros beetle",
+ "307": "weevil",
+ "308": "fly",
+ "309": "bee",
+ "310": "ant, emmet, pismire",
+ "311": "grasshopper, hopper",
+ "312": "cricket",
+ "313": "walking stick, walkingstick, stick insect",
+ "314": "cockroach, roach",
+ "315": "mantis, mantid",
+ "316": "cicada, cicala",
+ "317": "leafhopper",
+ "318": "lacewing, lacewing fly",
+ "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+ "320": "damselfly",
+ "321": "admiral",
+ "322": "ringlet, ringlet butterfly",
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
+ "324": "cabbage butterfly",
+ "325": "sulphur butterfly, sulfur butterfly",
+ "326": "lycaenid, lycaenid butterfly",
+ "327": "starfish, sea star",
+ "328": "sea urchin",
+ "329": "sea cucumber, holothurian",
+ "330": "wood rabbit, cottontail, cottontail rabbit",
+ "331": "hare",
+ "332": "Angora, Angora rabbit",
+ "333": "hamster",
+ "334": "porcupine, hedgehog",
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
+ "336": "marmot",
+ "337": "beaver",
+ "338": "guinea pig, Cavia cobaya",
+ "339": "sorrel",
+ "340": "zebra",
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
+ "342": "wild boar, boar, Sus scrofa",
+ "343": "warthog",
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
+ "345": "ox",
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
+ "347": "bison",
+ "348": "ram, tup",
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
+ "350": "ibex, Capra ibex",
+ "351": "hartebeest",
+ "352": "impala, Aepyceros melampus",
+ "353": "gazelle",
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
+ "355": "llama",
+ "356": "weasel",
+ "357": "mink",
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
+ "359": "black-footed ferret, ferret, Mustela nigripes",
+ "360": "otter",
+ "361": "skunk, polecat, wood pussy",
+ "362": "badger",
+ "363": "armadillo",
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
+ "366": "gorilla, Gorilla gorilla",
+ "367": "chimpanzee, chimp, Pan troglodytes",
+ "368": "gibbon, Hylobates lar",
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
+ "370": "guenon, guenon monkey",
+ "371": "patas, hussar monkey, Erythrocebus patas",
+ "372": "baboon",
+ "373": "macaque",
+ "374": "langur",
+ "375": "colobus, colobus monkey",
+ "376": "proboscis monkey, Nasalis larvatus",
+ "377": "marmoset",
+ "378": "capuchin, ringtail, Cebus capucinus",
+ "379": "howler monkey, howler",
+ "380": "titi, titi monkey",
+ "381": "spider monkey, Ateles geoffroyi",
+ "382": "squirrel monkey, Saimiri sciureus",
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
+ "385": "Indian elephant, Elephas maximus",
+ "386": "African elephant, Loxodonta africana",
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
+ "389": "barracouta, snoek",
+ "390": "eel",
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
+ "392": "rock beauty, Holocanthus tricolor",
+ "393": "anemone fish",
+ "394": "sturgeon",
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
+ "396": "lionfish",
+ "397": "puffer, pufferfish, blowfish, globefish",
+ "398": "abacus",
+ "399": "abaya",
+ "400": "academic gown, academic robe, judge's robe",
+ "401": "accordion, piano accordion, squeeze box",
+ "402": "acoustic guitar",
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
+ "404": "airliner",
+ "405": "airship, dirigible",
+ "406": "altar",
+ "407": "ambulance",
+ "408": "amphibian, amphibious vehicle",
+ "409": "analog clock",
+ "410": "apiary, bee house",
+ "411": "apron",
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
+ "413": "assault rifle, assault gun",
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
+ "415": "bakery, bakeshop, bakehouse",
+ "416": "balance beam, beam",
+ "417": "balloon",
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
+ "419": "Band Aid",
+ "420": "banjo",
+ "421": "bannister, banister, balustrade, balusters, handrail",
+ "422": "barbell",
+ "423": "barber chair",
+ "424": "barbershop",
+ "425": "barn",
+ "426": "barometer",
+ "427": "barrel, cask",
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
+ "429": "baseball",
+ "430": "basketball",
+ "431": "bassinet",
+ "432": "bassoon",
+ "433": "bathing cap, swimming cap",
+ "434": "bath towel",
+ "435": "bathtub, bathing tub, bath, tub",
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
+ "437": "beacon, lighthouse, beacon light, pharos",
+ "438": "beaker",
+ "439": "bearskin, busby, shako",
+ "440": "beer bottle",
+ "441": "beer glass",
+ "442": "bell cote, bell cot",
+ "443": "bib",
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
+ "445": "bikini, two-piece",
+ "446": "binder, ring-binder",
+ "447": "binoculars, field glasses, opera glasses",
+ "448": "birdhouse",
+ "449": "boathouse",
+ "450": "bobsled, bobsleigh, bob",
+ "451": "bolo tie, bolo, bola tie, bola",
+ "452": "bonnet, poke bonnet",
+ "453": "bookcase",
+ "454": "bookshop, bookstore, bookstall",
+ "455": "bottlecap",
+ "456": "bow",
+ "457": "bow tie, bow-tie, bowtie",
+ "458": "brass, memorial tablet, plaque",
+ "459": "brassiere, bra, bandeau",
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
+ "461": "breastplate, aegis, egis",
+ "462": "broom",
+ "463": "bucket, pail",
+ "464": "buckle",
+ "465": "bulletproof vest",
+ "466": "bullet train, bullet",
+ "467": "butcher shop, meat market",
+ "468": "cab, hack, taxi, taxicab",
+ "469": "caldron, cauldron",
+ "470": "candle, taper, wax light",
+ "471": "cannon",
+ "472": "canoe",
+ "473": "can opener, tin opener",
+ "474": "cardigan",
+ "475": "car mirror",
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
+ "477": "carpenter's kit, tool kit",
+ "478": "carton",
+ "479": "car wheel",
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
+ "481": "cassette",
+ "482": "cassette player",
+ "483": "castle",
+ "484": "catamaran",
+ "485": "CD player",
+ "486": "cello, violoncello",
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
+ "488": "chain",
+ "489": "chainlink fence",
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
+ "491": "chain saw, chainsaw",
+ "492": "chest",
+ "493": "chiffonier, commode",
+ "494": "chime, bell, gong",
+ "495": "china cabinet, china closet",
+ "496": "Christmas stocking",
+ "497": "church, church building",
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
+ "499": "cleaver, meat cleaver, chopper",
+ "500": "cliff dwelling",
+ "501": "cloak",
+ "502": "clog, geta, patten, sabot",
+ "503": "cocktail shaker",
+ "504": "coffee mug",
+ "505": "coffeepot",
+ "506": "coil, spiral, volute, whorl, helix",
+ "507": "combination lock",
+ "508": "computer keyboard, keypad",
+ "509": "confectionery, confectionary, candy store",
+ "510": "container ship, containership, container vessel",
+ "511": "convertible",
+ "512": "corkscrew, bottle screw",
+ "513": "cornet, horn, trumpet, trump",
+ "514": "cowboy boot",
+ "515": "cowboy hat, ten-gallon hat",
+ "516": "cradle",
+ "517": "crane",
+ "518": "crash helmet",
+ "519": "crate",
+ "520": "crib, cot",
+ "521": "Crock Pot",
+ "522": "croquet ball",
+ "523": "crutch",
+ "524": "cuirass",
+ "525": "dam, dike, dyke",
+ "526": "desk",
+ "527": "desktop computer",
+ "528": "dial telephone, dial phone",
+ "529": "diaper, nappy, napkin",
+ "530": "digital clock",
+ "531": "digital watch",
+ "532": "dining table, board",
+ "533": "dishrag, dishcloth",
+ "534": "dishwasher, dish washer, dishwashing machine",
+ "535": "disk brake, disc brake",
+ "536": "dock, dockage, docking facility",
+ "537": "dogsled, dog sled, dog sleigh",
+ "538": "dome",
+ "539": "doormat, welcome mat",
+ "540": "drilling platform, offshore rig",
+ "541": "drum, membranophone, tympan",
+ "542": "drumstick",
+ "543": "dumbbell",
+ "544": "Dutch oven",
+ "545": "electric fan, blower",
+ "546": "electric guitar",
+ "547": "electric locomotive",
+ "548": "entertainment center",
+ "549": "envelope",
+ "550": "espresso maker",
+ "551": "face powder",
+ "552": "feather boa, boa",
+ "553": "file, file cabinet, filing cabinet",
+ "554": "fireboat",
+ "555": "fire engine, fire truck",
+ "556": "fire screen, fireguard",
+ "557": "flagpole, flagstaff",
+ "558": "flute, transverse flute",
+ "559": "folding chair",
+ "560": "football helmet",
+ "561": "forklift",
+ "562": "fountain",
+ "563": "fountain pen",
+ "564": "four-poster",
+ "565": "freight car",
+ "566": "French horn, horn",
+ "567": "frying pan, frypan, skillet",
+ "568": "fur coat",
+ "569": "garbage truck, dustcart",
+ "570": "gasmask, respirator, gas helmet",
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
+ "572": "goblet",
+ "573": "go-kart",
+ "574": "golf ball",
+ "575": "golfcart, golf cart",
+ "576": "gondola",
+ "577": "gong, tam-tam",
+ "578": "gown",
+ "579": "grand piano, grand",
+ "580": "greenhouse, nursery, glasshouse",
+ "581": "grille, radiator grille",
+ "582": "grocery store, grocery, food market, market",
+ "583": "guillotine",
+ "584": "hair slide",
+ "585": "hair spray",
+ "586": "half track",
+ "587": "hammer",
+ "588": "hamper",
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
+ "590": "hand-held computer, hand-held microcomputer",
+ "591": "handkerchief, hankie, hanky, hankey",
+ "592": "hard disc, hard disk, fixed disk",
+ "593": "harmonica, mouth organ, harp, mouth harp",
+ "594": "harp",
+ "595": "harvester, reaper",
+ "596": "hatchet",
+ "597": "holster",
+ "598": "home theater, home theatre",
+ "599": "honeycomb",
+ "600": "hook, claw",
+ "601": "hoopskirt, crinoline",
+ "602": "horizontal bar, high bar",
+ "603": "horse cart, horse-cart",
+ "604": "hourglass",
+ "605": "iPod",
+ "606": "iron, smoothing iron",
+ "607": "jack-o'-lantern",
+ "608": "jean, blue jean, denim",
+ "609": "jeep, landrover",
+ "610": "jersey, T-shirt, tee shirt",
+ "611": "jigsaw puzzle",
+ "612": "jinrikisha, ricksha, rickshaw",
+ "613": "joystick",
+ "614": "kimono",
+ "615": "knee pad",
+ "616": "knot",
+ "617": "lab coat, laboratory coat",
+ "618": "ladle",
+ "619": "lampshade, lamp shade",
+ "620": "laptop, laptop computer",
+ "621": "lawn mower, mower",
+ "622": "lens cap, lens cover",
+ "623": "letter opener, paper knife, paperknife",
+ "624": "library",
+ "625": "lifeboat",
+ "626": "lighter, light, igniter, ignitor",
+ "627": "limousine, limo",
+ "628": "liner, ocean liner",
+ "629": "lipstick, lip rouge",
+ "630": "Loafer",
+ "631": "lotion",
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
+ "633": "loupe, jeweler's loupe",
+ "634": "lumbermill, sawmill",
+ "635": "magnetic compass",
+ "636": "mailbag, postbag",
+ "637": "mailbox, letter box",
+ "638": "maillot",
+ "639": "maillot, tank suit",
+ "640": "manhole cover",
+ "641": "maraca",
+ "642": "marimba, xylophone",
+ "643": "mask",
+ "644": "matchstick",
+ "645": "maypole",
+ "646": "maze, labyrinth",
+ "647": "measuring cup",
+ "648": "medicine chest, medicine cabinet",
+ "649": "megalith, megalithic structure",
+ "650": "microphone, mike",
+ "651": "microwave, microwave oven",
+ "652": "military uniform",
+ "653": "milk can",
+ "654": "minibus",
+ "655": "miniskirt, mini",
+ "656": "minivan",
+ "657": "missile",
+ "658": "mitten",
+ "659": "mixing bowl",
+ "660": "mobile home, manufactured home",
+ "661": "Model T",
+ "662": "modem",
+ "663": "monastery",
+ "664": "monitor",
+ "665": "moped",
+ "666": "mortar",
+ "667": "mortarboard",
+ "668": "mosque",
+ "669": "mosquito net",
+ "670": "motor scooter, scooter",
+ "671": "mountain bike, all-terrain bike, off-roader",
+ "672": "mountain tent",
+ "673": "mouse, computer mouse",
+ "674": "mousetrap",
+ "675": "moving van",
+ "676": "muzzle",
+ "677": "nail",
+ "678": "neck brace",
+ "679": "necklace",
+ "680": "nipple",
+ "681": "notebook, notebook computer",
+ "682": "obelisk",
+ "683": "oboe, hautboy, hautbois",
+ "684": "ocarina, sweet potato",
+ "685": "odometer, hodometer, mileometer, milometer",
+ "686": "oil filter",
+ "687": "organ, pipe organ",
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
+ "689": "overskirt",
+ "690": "oxcart",
+ "691": "oxygen mask",
+ "692": "packet",
+ "693": "paddle, boat paddle",
+ "694": "paddlewheel, paddle wheel",
+ "695": "padlock",
+ "696": "paintbrush",
+ "697": "pajama, pyjama, pj's, jammies",
+ "698": "palace",
+ "699": "panpipe, pandean pipe, syrinx",
+ "700": "paper towel",
+ "701": "parachute, chute",
+ "702": "parallel bars, bars",
+ "703": "park bench",
+ "704": "parking meter",
+ "705": "passenger car, coach, carriage",
+ "706": "patio, terrace",
+ "707": "pay-phone, pay-station",
+ "708": "pedestal, plinth, footstall",
+ "709": "pencil box, pencil case",
+ "710": "pencil sharpener",
+ "711": "perfume, essence",
+ "712": "Petri dish",
+ "713": "photocopier",
+ "714": "pick, plectrum, plectron",
+ "715": "pickelhaube",
+ "716": "picket fence, paling",
+ "717": "pickup, pickup truck",
+ "718": "pier",
+ "719": "piggy bank, penny bank",
+ "720": "pill bottle",
+ "721": "pillow",
+ "722": "ping-pong ball",
+ "723": "pinwheel",
+ "724": "pirate, pirate ship",
+ "725": "pitcher, ewer",
+ "726": "plane, carpenter's plane, woodworking plane",
+ "727": "planetarium",
+ "728": "plastic bag",
+ "729": "plate rack",
+ "730": "plow, plough",
+ "731": "plunger, plumber's helper",
+ "732": "Polaroid camera, Polaroid Land camera",
+ "733": "pole",
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
+ "735": "poncho",
+ "736": "pool table, billiard table, snooker table",
+ "737": "pop bottle, soda bottle",
+ "738": "pot, flowerpot",
+ "739": "potter's wheel",
+ "740": "power drill",
+ "741": "prayer rug, prayer mat",
+ "742": "printer",
+ "743": "prison, prison house",
+ "744": "projectile, missile",
+ "745": "projector",
+ "746": "puck, hockey puck",
+ "747": "punching bag, punch bag, punching ball, punchball",
+ "748": "purse",
+ "749": "quill, quill pen",
+ "750": "quilt, comforter, comfort, puff",
+ "751": "racer, race car, racing car",
+ "752": "racket, racquet",
+ "753": "radiator",
+ "754": "radio, wireless",
+ "755": "radio telescope, radio reflector",
+ "756": "rain barrel",
+ "757": "recreational vehicle, RV, R.V.",
+ "758": "reel",
+ "759": "reflex camera",
+ "760": "refrigerator, icebox",
+ "761": "remote control, remote",
+ "762": "restaurant, eating house, eating place, eatery",
+ "763": "revolver, six-gun, six-shooter",
+ "764": "rifle",
+ "765": "rocking chair, rocker",
+ "766": "rotisserie",
+ "767": "rubber eraser, rubber, pencil eraser",
+ "768": "rugby ball",
+ "769": "rule, ruler",
+ "770": "running shoe",
+ "771": "safe",
+ "772": "safety pin",
+ "773": "saltshaker, salt shaker",
+ "774": "sandal",
+ "775": "sarong",
+ "776": "sax, saxophone",
+ "777": "scabbard",
+ "778": "scale, weighing machine",
+ "779": "school bus",
+ "780": "schooner",
+ "781": "scoreboard",
+ "782": "screen, CRT screen",
+ "783": "screw",
+ "784": "screwdriver",
+ "785": "seat belt, seatbelt",
+ "786": "sewing machine",
+ "787": "shield, buckler",
+ "788": "shoe shop, shoe-shop, shoe store",
+ "789": "shoji",
+ "790": "shopping basket",
+ "791": "shopping cart",
+ "792": "shovel",
+ "793": "shower cap",
+ "794": "shower curtain",
+ "795": "ski",
+ "796": "ski mask",
+ "797": "sleeping bag",
+ "798": "slide rule, slipstick",
+ "799": "sliding door",
+ "800": "slot, one-armed bandit",
+ "801": "snorkel",
+ "802": "snowmobile",
+ "803": "snowplow, snowplough",
+ "804": "soap dispenser",
+ "805": "soccer ball",
+ "806": "sock",
+ "807": "solar dish, solar collector, solar furnace",
+ "808": "sombrero",
+ "809": "soup bowl",
+ "810": "space bar",
+ "811": "space heater",
+ "812": "space shuttle",
+ "813": "spatula",
+ "814": "speedboat",
+ "815": "spider web, spider's web",
+ "816": "spindle",
+ "817": "sports car, sport car",
+ "818": "spotlight, spot",
+ "819": "stage",
+ "820": "steam locomotive",
+ "821": "steel arch bridge",
+ "822": "steel drum",
+ "823": "stethoscope",
+ "824": "stole",
+ "825": "stone wall",
+ "826": "stopwatch, stop watch",
+ "827": "stove",
+ "828": "strainer",
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
+ "830": "stretcher",
+ "831": "studio couch, day bed",
+ "832": "stupa, tope",
+ "833": "submarine, pigboat, sub, U-boat",
+ "834": "suit, suit of clothes",
+ "835": "sundial",
+ "836": "sunglass",
+ "837": "sunglasses, dark glasses, shades",
+ "838": "sunscreen, sunblock, sun blocker",
+ "839": "suspension bridge",
+ "840": "swab, swob, mop",
+ "841": "sweatshirt",
+ "842": "swimming trunks, bathing trunks",
+ "843": "swing",
+ "844": "switch, electric switch, electrical switch",
+ "845": "syringe",
+ "846": "table lamp",
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
+ "848": "tape player",
+ "849": "teapot",
+ "850": "teddy, teddy bear",
+ "851": "television, television system",
+ "852": "tennis ball",
+ "853": "thatch, thatched roof",
+ "854": "theater curtain, theatre curtain",
+ "855": "thimble",
+ "856": "thresher, thrasher, threshing machine",
+ "857": "throne",
+ "858": "tile roof",
+ "859": "toaster",
+ "860": "tobacco shop, tobacconist shop, tobacconist",
+ "861": "toilet seat",
+ "862": "torch",
+ "863": "totem pole",
+ "864": "tow truck, tow car, wrecker",
+ "865": "toyshop",
+ "866": "tractor",
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
+ "868": "tray",
+ "869": "trench coat",
+ "870": "tricycle, trike, velocipede",
+ "871": "trimaran",
+ "872": "tripod",
+ "873": "triumphal arch",
+ "874": "trolleybus, trolley coach, trackless trolley",
+ "875": "trombone",
+ "876": "tub, vat",
+ "877": "turnstile",
+ "878": "typewriter keyboard",
+ "879": "umbrella",
+ "880": "unicycle, monocycle",
+ "881": "upright, upright piano",
+ "882": "vacuum, vacuum cleaner",
+ "883": "vase",
+ "884": "vault",
+ "885": "velvet",
+ "886": "vending machine",
+ "887": "vestment",
+ "888": "viaduct",
+ "889": "violin, fiddle",
+ "890": "volleyball",
+ "891": "waffle iron",
+ "892": "wall clock",
+ "893": "wallet, billfold, notecase, pocketbook",
+ "894": "wardrobe, closet, press",
+ "895": "warplane, military plane",
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
+ "897": "washer, automatic washer, washing machine",
+ "898": "water bottle",
+ "899": "water jug",
+ "900": "water tower",
+ "901": "whiskey jug",
+ "902": "whistle",
+ "903": "wig",
+ "904": "window screen",
+ "905": "window shade",
+ "906": "Windsor tie",
+ "907": "wine bottle",
+ "908": "wing",
+ "909": "wok",
+ "910": "wooden spoon",
+ "911": "wool, woolen, woollen",
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
+ "913": "wreck",
+ "914": "yawl",
+ "915": "yurt",
+ "916": "web site, website, internet site, site",
+ "917": "comic book",
+ "918": "crossword puzzle, crossword",
+ "919": "street sign",
+ "920": "traffic light, traffic signal, stoplight",
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
+ "922": "menu",
+ "923": "plate",
+ "924": "guacamole",
+ "925": "consomme",
+ "926": "hot pot, hotpot",
+ "927": "trifle",
+ "928": "ice cream, icecream",
+ "929": "ice lolly, lolly, lollipop, popsicle",
+ "930": "French loaf",
+ "931": "bagel, beigel",
+ "932": "pretzel",
+ "933": "cheeseburger",
+ "934": "hotdog, hot dog, red hot",
+ "935": "mashed potato",
+ "936": "head cabbage",
+ "937": "broccoli",
+ "938": "cauliflower",
+ "939": "zucchini, courgette",
+ "940": "spaghetti squash",
+ "941": "acorn squash",
+ "942": "butternut squash",
+ "943": "cucumber, cuke",
+ "944": "artichoke, globe artichoke",
+ "945": "bell pepper",
+ "946": "cardoon",
+ "947": "mushroom",
+ "948": "Granny Smith",
+ "949": "strawberry",
+ "950": "orange",
+ "951": "lemon",
+ "952": "fig",
+ "953": "pineapple, ananas",
+ "954": "banana",
+ "955": "jackfruit, jak, jack",
+ "956": "custard apple",
+ "957": "pomegranate",
+ "958": "hay",
+ "959": "carbonara",
+ "960": "chocolate sauce, chocolate syrup",
+ "961": "dough",
+ "962": "meat loaf, meatloaf",
+ "963": "pizza, pizza pie",
+ "964": "potpie",
+ "965": "burrito",
+ "966": "red wine",
+ "967": "espresso",
+ "968": "cup",
+ "969": "eggnog",
+ "970": "alp",
+ "971": "bubble",
+ "972": "cliff, drop, drop-off",
+ "973": "coral reef",
+ "974": "geyser",
+ "975": "lakeside, lakeshore",
+ "976": "promontory, headland, head, foreland",
+ "977": "sandbar, sand bar",
+ "978": "seashore, coast, seacoast, sea-coast",
+ "979": "valley, vale",
+ "980": "volcano",
+ "981": "ballplayer, baseball player",
+ "982": "groom, bridegroom",
+ "983": "scuba diver",
+ "984": "rapeseed",
+ "985": "daisy",
+ "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
+ "987": "corn",
+ "988": "acorn",
+ "989": "hip, rose hip, rosehip",
+ "990": "buckeye, horse chestnut, conker",
+ "991": "coral fungus",
+ "992": "agaric",
+ "993": "gyromitra",
+ "994": "stinkhorn, carrion fungus",
+ "995": "earthstar",
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
+ "997": "bolete",
+ "998": "ear, spike, capitulum",
+ "999": "toilet tissue, toilet paper, bathroom tissue"
+ },
+ "image_size": 224,
+ "initializer_range": 0.02,
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "Afghan hound, Afghan": 160,
+ "African chameleon, Chamaeleo chamaeleon": 47,
+ "African crocodile, Nile crocodile, Crocodylus niloticus": 49,
+ "African elephant, Loxodonta africana": 386,
+ "African grey, African gray, Psittacus erithacus": 87,
+ "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus": 275,
+ "Airedale, Airedale terrier": 191,
+ "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier": 180,
+ "American alligator, Alligator mississipiensis": 50,
+ "American black bear, black bear, Ursus americanus, Euarctos americanus": 295,
+ "American chameleon, anole, Anolis carolinensis": 40,
+ "American coot, marsh hen, mud hen, water hen, Fulica americana": 137,
+ "American egret, great white heron, Egretta albus": 132,
+ "American lobster, Northern lobster, Maine lobster, Homarus americanus": 122,
+ "Angora, Angora rabbit": 332,
+ "Appenzeller": 240,
+ "Arabian camel, dromedary, Camelus dromedarius": 354,
+ "Arctic fox, white fox, Alopex lagopus": 279,
+ "Australian terrier": 193,
+ "Band Aid": 419,
+ "Bedlington terrier": 181,
+ "Bernese mountain dog": 239,
+ "Blenheim spaniel": 156,
+ "Border collie": 232,
+ "Border terrier": 182,
+ "Boston bull, Boston terrier": 195,
+ "Bouvier des Flandres, Bouviers des Flandres": 233,
+ "Brabancon griffon": 262,
+ "Brittany spaniel": 215,
+ "CD player": 485,
+ "Cardigan, Cardigan Welsh corgi": 264,
+ "Chesapeake Bay retriever": 209,
+ "Chihuahua": 151,
+ "Christmas stocking": 496,
+ "Crock Pot": 521,
+ "Dandie Dinmont, Dandie Dinmont terrier": 194,
+ "Doberman, Doberman pinscher": 236,
+ "Dungeness crab, Cancer magister": 118,
+ "Dutch oven": 544,
+ "Egyptian cat": 285,
+ "English foxhound": 167,
+ "English setter": 212,
+ "English springer, English springer spaniel": 217,
+ "EntleBucher": 241,
+ "Eskimo dog, husky": 248,
+ "European fire salamander, Salamandra salamandra": 25,
+ "European gallinule, Porphyrio porphyrio": 136,
+ "French bulldog": 245,
+ "French horn, horn": 566,
+ "French loaf": 930,
+ "German shepherd, German shepherd dog, German police dog, alsatian": 235,
+ "German short-haired pointer": 210,
+ "Gila monster, Heloderma suspectum": 45,
+ "Gordon setter": 214,
+ "Granny Smith": 948,
+ "Great Dane": 246,
+ "Great Pyrenees": 257,
+ "Greater Swiss Mountain dog": 238,
+ "Ibizan hound, Ibizan Podenco": 173,
+ "Indian cobra, Naja naja": 63,
+ "Indian elephant, Elephas maximus": 385,
+ "Irish setter, red setter": 213,
+ "Irish terrier": 184,
+ "Irish water spaniel": 221,
+ "Irish wolfhound": 170,
+ "Italian greyhound": 171,
+ "Japanese spaniel": 152,
+ "Kerry blue terrier": 183,
+ "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis": 48,
+ "Labrador retriever": 208,
+ "Lakeland terrier": 189,
+ "Leonberg": 255,
+ "Lhasa, Lhasa apso": 204,
+ "Loafer": 630,
+ "Madagascar cat, ring-tailed lemur, Lemur catta": 383,
+ "Maltese dog, Maltese terrier, Maltese": 153,
+ "Mexican hairless": 268,
+ "Model T": 661,
+ "Newfoundland, Newfoundland dog": 256,
+ "Norfolk terrier": 185,
+ "Norwegian elkhound, elkhound": 174,
+ "Norwich terrier": 186,
+ "Old English sheepdog, bobtail": 229,
+ "Pekinese, Pekingese, Peke": 154,
+ "Pembroke, Pembroke Welsh corgi": 263,
+ "Persian cat": 283,
+ "Petri dish": 712,
+ "Polaroid camera, Polaroid Land camera": 732,
+ "Pomeranian": 259,
+ "Rhodesian ridgeback": 159,
+ "Rottweiler": 234,
+ "Saint Bernard, St Bernard": 247,
+ "Saluki, gazelle hound": 176,
+ "Samoyed, Samoyede": 258,
+ "Scotch terrier, Scottish terrier, Scottie": 199,
+ "Scottish deerhound, deerhound": 177,
+ "Sealyham terrier, Sealyham": 190,
+ "Shetland sheepdog, Shetland sheep dog, Shetland": 230,
+ "Shih-Tzu": 155,
+ "Siamese cat, Siamese": 284,
+ "Siberian husky": 250,
+ "Staffordshire bullterrier, Staffordshire bull terrier": 179,
+ "Sussex spaniel": 220,
+ "Tibetan mastiff": 244,
+ "Tibetan terrier, chrysanthemum dog": 200,
+ "Walker hound, Walker foxhound": 166,
+ "Weimaraner": 178,
+ "Welsh springer spaniel": 218,
+ "West Highland white terrier": 203,
+ "Windsor tie": 906,
+ "Yorkshire terrier": 187,
+ "abacus": 398,
+ "abaya": 399,
+ "academic gown, academic robe, judge's robe": 400,
+ "accordion, piano accordion, squeeze box": 401,
+ "acorn": 988,
+ "acorn squash": 941,
+ "acoustic guitar": 402,
+ "admiral": 321,
+ "affenpinscher, monkey pinscher, monkey dog": 252,
+ "agama": 42,
+ "agaric": 992,
+ "aircraft carrier, carrier, flattop, attack aircraft carrier": 403,
+ "airliner": 404,
+ "airship, dirigible": 405,
+ "albatross, mollymawk": 146,
+ "alligator lizard": 44,
+ "alp": 970,
+ "altar": 406,
+ "ambulance": 407,
+ "amphibian, amphibious vehicle": 408,
+ "analog clock": 409,
+ "anemone fish": 393,
+ "ant, emmet, pismire": 310,
+ "apiary, bee house": 410,
+ "apron": 411,
+ "armadillo": 363,
+ "artichoke, globe artichoke": 944,
+ "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin": 412,
+ "assault rifle, assault gun": 413,
+ "axolotl, mud puppy, Ambystoma mexicanum": 29,
+ "baboon": 372,
+ "backpack, back pack, knapsack, packsack, rucksack, haversack": 414,
+ "badger": 362,
+ "bagel, beigel": 931,
+ "bakery, bakeshop, bakehouse": 415,
+ "balance beam, beam": 416,
+ "bald eagle, American eagle, Haliaeetus leucocephalus": 22,
+ "balloon": 417,
+ "ballplayer, baseball player": 981,
+ "ballpoint, ballpoint pen, ballpen, Biro": 418,
+ "banana": 954,
+ "banded gecko": 38,
+ "banjo": 420,
+ "bannister, banister, balustrade, balusters, handrail": 421,
+ "barbell": 422,
+ "barber chair": 423,
+ "barbershop": 424,
+ "barn": 425,
+ "barn spider, Araneus cavaticus": 73,
+ "barometer": 426,
+ "barracouta, snoek": 389,
+ "barrel, cask": 427,
+ "barrow, garden cart, lawn cart, wheelbarrow": 428,
+ "baseball": 429,
+ "basenji": 253,
+ "basketball": 430,
+ "basset, basset hound": 161,
+ "bassinet": 431,
+ "bassoon": 432,
+ "bath towel": 434,
+ "bathing cap, swimming cap": 433,
+ "bathtub, bathing tub, bath, tub": 435,
+ "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon": 436,
+ "beacon, lighthouse, beacon light, pharos": 437,
+ "beagle": 162,
+ "beaker": 438,
+ "bearskin, busby, shako": 439,
+ "beaver": 337,
+ "bee": 309,
+ "bee eater": 92,
+ "beer bottle": 440,
+ "beer glass": 441,
+ "bell cote, bell cot": 442,
+ "bell pepper": 945,
+ "bib": 443,
+ "bicycle-built-for-two, tandem bicycle, tandem": 444,
+ "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis": 349,
+ "bikini, two-piece": 445,
+ "binder, ring-binder": 446,
+ "binoculars, field glasses, opera glasses": 447,
+ "birdhouse": 448,
+ "bison": 347,
+ "bittern": 133,
+ "black and gold garden spider, Argiope aurantia": 72,
+ "black grouse": 80,
+ "black stork, Ciconia nigra": 128,
+ "black swan, Cygnus atratus": 100,
+ "black widow, Latrodectus mactans": 75,
+ "black-and-tan coonhound": 165,
+ "black-footed ferret, ferret, Mustela nigripes": 359,
+ "bloodhound, sleuthhound": 163,
+ "bluetick": 164,
+ "boa constrictor, Constrictor constrictor": 61,
+ "boathouse": 449,
+ "bobsled, bobsleigh, bob": 450,
+ "bolete": 997,
+ "bolo tie, bolo, bola tie, bola": 451,
+ "bonnet, poke bonnet": 452,
+ "book jacket, dust cover, dust jacket, dust wrapper": 921,
+ "bookcase": 453,
+ "bookshop, bookstore, bookstall": 454,
+ "borzoi, Russian wolfhound": 169,
+ "bottlecap": 455,
+ "bow": 456,
+ "bow tie, bow-tie, bowtie": 457,
+ "box turtle, box tortoise": 37,
+ "boxer": 242,
+ "brain coral": 109,
+ "brambling, Fringilla montifringilla": 10,
+ "brass, memorial tablet, plaque": 458,
+ "brassiere, bra, bandeau": 459,
+ "breakwater, groin, groyne, mole, bulwark, seawall, jetty": 460,
+ "breastplate, aegis, egis": 461,
+ "briard": 226,
+ "broccoli": 937,
+ "broom": 462,
+ "brown bear, bruin, Ursus arctos": 294,
+ "bubble": 971,
+ "bucket, pail": 463,
+ "buckeye, horse chestnut, conker": 990,
+ "buckle": 464,
+ "bulbul": 16,
+ "bull mastiff": 243,
+ "bullet train, bullet": 466,
+ "bulletproof vest": 465,
+ "bullfrog, Rana catesbeiana": 30,
+ "burrito": 965,
+ "bustard": 138,
+ "butcher shop, meat market": 467,
+ "butternut squash": 942,
+ "cab, hack, taxi, taxicab": 468,
+ "cabbage butterfly": 324,
+ "cairn, cairn terrier": 192,
+ "caldron, cauldron": 469,
+ "can opener, tin opener": 473,
+ "candle, taper, wax light": 470,
+ "cannon": 471,
+ "canoe": 472,
+ "capuchin, ringtail, Cebus capucinus": 378,
+ "car mirror": 475,
+ "car wheel": 479,
+ "carbonara": 959,
+ "cardigan": 474,
+ "cardoon": 946,
+ "carousel, carrousel, merry-go-round, roundabout, whirligig": 476,
+ "carpenter's kit, tool kit": 477,
+ "carton": 478,
+ "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM": 480,
+ "cassette": 481,
+ "cassette player": 482,
+ "castle": 483,
+ "catamaran": 484,
+ "cauliflower": 938,
+ "cello, violoncello": 486,
+ "cellular telephone, cellular phone, cellphone, cell, mobile phone": 487,
+ "centipede": 79,
+ "chain": 488,
+ "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour": 490,
+ "chain saw, chainsaw": 491,
+ "chainlink fence": 489,
+ "chambered nautilus, pearly nautilus, nautilus": 117,
+ "cheeseburger": 933,
+ "cheetah, chetah, Acinonyx jubatus": 293,
+ "chest": 492,
+ "chickadee": 19,
+ "chiffonier, commode": 493,
+ "chime, bell, gong": 494,
+ "chimpanzee, chimp, Pan troglodytes": 367,
+ "china cabinet, china closet": 495,
+ "chiton, coat-of-mail shell, sea cradle, polyplacophore": 116,
+ "chocolate sauce, chocolate syrup": 960,
+ "chow, chow chow": 260,
+ "church, church building": 497,
+ "cicada, cicala": 316,
+ "cinema, movie theater, movie theatre, movie house, picture palace": 498,
+ "cleaver, meat cleaver, chopper": 499,
+ "cliff dwelling": 500,
+ "cliff, drop, drop-off": 972,
+ "cloak": 501,
+ "clog, geta, patten, sabot": 502,
+ "clumber, clumber spaniel": 216,
+ "cock": 7,
+ "cocker spaniel, English cocker spaniel, cocker": 219,
+ "cockroach, roach": 314,
+ "cocktail shaker": 503,
+ "coffee mug": 504,
+ "coffeepot": 505,
+ "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch": 391,
+ "coil, spiral, volute, whorl, helix": 506,
+ "collie": 231,
+ "colobus, colobus monkey": 375,
+ "combination lock": 507,
+ "comic book": 917,
+ "common iguana, iguana, Iguana iguana": 39,
+ "common newt, Triturus vulgaris": 26,
+ "computer keyboard, keypad": 508,
+ "conch": 112,
+ "confectionery, confectionary, candy store": 509,
+ "consomme": 925,
+ "container ship, containership, container vessel": 510,
+ "convertible": 511,
+ "coral fungus": 991,
+ "coral reef": 973,
+ "corkscrew, bottle screw": 512,
+ "corn": 987,
+ "cornet, horn, trumpet, trump": 513,
+ "coucal": 91,
+ "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor": 286,
+ "cowboy boot": 514,
+ "cowboy hat, ten-gallon hat": 515,
+ "coyote, prairie wolf, brush wolf, Canis latrans": 272,
+ "cradle": 516,
+ "crane": 517,
+ "crash helmet": 518,
+ "crate": 519,
+ "crayfish, crawfish, crawdad, crawdaddy": 124,
+ "crib, cot": 520,
+ "cricket": 312,
+ "croquet ball": 522,
+ "crossword puzzle, crossword": 918,
+ "crutch": 523,
+ "cucumber, cuke": 943,
+ "cuirass": 524,
+ "cup": 968,
+ "curly-coated retriever": 206,
+ "custard apple": 956,
+ "daisy": 985,
+ "dalmatian, coach dog, carriage dog": 251,
+ "dam, dike, dyke": 525,
+ "damselfly": 320,
+ "desk": 526,
+ "desktop computer": 527,
+ "dhole, Cuon alpinus": 274,
+ "dial telephone, dial phone": 528,
+ "diamondback, diamondback rattlesnake, Crotalus adamanteus": 67,
+ "diaper, nappy, napkin": 529,
+ "digital clock": 530,
+ "digital watch": 531,
+ "dingo, warrigal, warragal, Canis dingo": 273,
+ "dining table, board": 532,
+ "dishrag, dishcloth": 533,
+ "dishwasher, dish washer, dishwashing machine": 534,
+ "disk brake, disc brake": 535,
+ "dock, dockage, docking facility": 536,
+ "dogsled, dog sled, dog sleigh": 537,
+ "dome": 538,
+ "doormat, welcome mat": 539,
+ "dough": 961,
+ "dowitcher": 142,
+ "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk": 319,
+ "drake": 97,
+ "drilling platform, offshore rig": 540,
+ "drum, membranophone, tympan": 541,
+ "drumstick": 542,
+ "dugong, Dugong dugon": 149,
+ "dumbbell": 543,
+ "dung beetle": 305,
+ "ear, spike, capitulum": 998,
+ "earthstar": 995,
+ "echidna, spiny anteater, anteater": 102,
+ "eel": 390,
+ "eft": 27,
+ "eggnog": 969,
+ "electric fan, blower": 545,
+ "electric guitar": 546,
+ "electric locomotive": 547,
+ "electric ray, crampfish, numbfish, torpedo": 5,
+ "entertainment center": 548,
+ "envelope": 549,
+ "espresso": 967,
+ "espresso maker": 550,
+ "face powder": 551,
+ "feather boa, boa": 552,
+ "fiddler crab": 120,
+ "fig": 952,
+ "file, file cabinet, filing cabinet": 553,
+ "fire engine, fire truck": 555,
+ "fire screen, fireguard": 556,
+ "fireboat": 554,
+ "flagpole, flagstaff": 557,
+ "flamingo": 130,
+ "flat-coated retriever": 205,
+ "flatworm, platyhelminth": 110,
+ "flute, transverse flute": 558,
+ "fly": 308,
+ "folding chair": 559,
+ "football helmet": 560,
+ "forklift": 561,
+ "fountain": 562,
+ "fountain pen": 563,
+ "four-poster": 564,
+ "fox squirrel, eastern fox squirrel, Sciurus niger": 335,
+ "freight car": 565,
+ "frilled lizard, Chlamydosaurus kingi": 43,
+ "frying pan, frypan, skillet": 567,
+ "fur coat": 568,
+ "gar, garfish, garpike, billfish, Lepisosteus osseus": 395,
+ "garbage truck, dustcart": 569,
+ "garden spider, Aranea diademata": 74,
+ "garter snake, grass snake": 57,
+ "gas pump, gasoline pump, petrol pump, island dispenser": 571,
+ "gasmask, respirator, gas helmet": 570,
+ "gazelle": 353,
+ "geyser": 974,
+ "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca": 388,
+ "giant schnauzer": 197,
+ "gibbon, Hylobates lar": 368,
+ "go-kart": 573,
+ "goblet": 572,
+ "golden retriever": 207,
+ "goldfinch, Carduelis carduelis": 11,
+ "goldfish, Carassius auratus": 1,
+ "golf ball": 574,
+ "golfcart, golf cart": 575,
+ "gondola": 576,
+ "gong, tam-tam": 577,
+ "goose": 99,
+ "gorilla, Gorilla gorilla": 366,
+ "gown": 578,
+ "grand piano, grand": 579,
+ "grasshopper, hopper": 311,
+ "great grey owl, great gray owl, Strix nebulosa": 24,
+ "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias": 2,
+ "green lizard, Lacerta viridis": 46,
+ "green mamba": 64,
+ "green snake, grass snake": 55,
+ "greenhouse, nursery, glasshouse": 580,
+ "grey fox, gray fox, Urocyon cinereoargenteus": 280,
+ "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus": 147,
+ "grille, radiator grille": 581,
+ "grocery store, grocery, food market, market": 582,
+ "groenendael": 224,
+ "groom, bridegroom": 982,
+ "ground beetle, carabid beetle": 302,
+ "guacamole": 924,
+ "guenon, guenon monkey": 370,
+ "guillotine": 583,
+ "guinea pig, Cavia cobaya": 338,
+ "gyromitra": 993,
+ "hair slide": 584,
+ "hair spray": 585,
+ "half track": 586,
+ "hammer": 587,
+ "hammerhead, hammerhead shark": 4,
+ "hamper": 588,
+ "hamster": 333,
+ "hand blower, blow dryer, blow drier, hair dryer, hair drier": 589,
+ "hand-held computer, hand-held microcomputer": 590,
+ "handkerchief, hankie, hanky, hankey": 591,
+ "hard disc, hard disk, fixed disk": 592,
+ "hare": 331,
+ "harmonica, mouth organ, harp, mouth harp": 593,
+ "harp": 594,
+ "hartebeest": 351,
+ "harvester, reaper": 595,
+ "harvestman, daddy longlegs, Phalangium opilio": 70,
+ "hatchet": 596,
+ "hay": 958,
+ "head cabbage": 936,
+ "hen": 8,
+ "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa": 996,
+ "hermit crab": 125,
+ "hip, rose hip, rosehip": 989,
+ "hippopotamus, hippo, river horse, Hippopotamus amphibius": 344,
+ "hog, pig, grunter, squealer, Sus scrofa": 341,
+ "hognose snake, puff adder, sand viper": 54,
+ "holster": 597,
+ "home theater, home theatre": 598,
+ "honeycomb": 599,
+ "hook, claw": 600,
+ "hoopskirt, crinoline": 601,
+ "horizontal bar, high bar": 602,
+ "hornbill": 93,
+ "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus": 66,
+ "horse cart, horse-cart": 603,
+ "hot pot, hotpot": 926,
+ "hotdog, hot dog, red hot": 934,
+ "hourglass": 604,
+ "house finch, linnet, Carpodacus mexicanus": 12,
+ "howler monkey, howler": 379,
+ "hummingbird": 94,
+ "hyena, hyaena": 276,
+ "iPod": 605,
+ "ibex, Capra ibex": 350,
+ "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus": 296,
+ "ice cream, icecream": 928,
+ "ice lolly, lolly, lollipop, popsicle": 929,
+ "impala, Aepyceros melampus": 352,
+ "indigo bunting, indigo finch, indigo bird, Passerina cyanea": 14,
+ "indri, indris, Indri indri, Indri brevicaudatus": 384,
+ "iron, smoothing iron": 606,
+ "isopod": 126,
+ "jacamar": 95,
+ "jack-o'-lantern": 607,
+ "jackfruit, jak, jack": 955,
+ "jaguar, panther, Panthera onca, Felis onca": 290,
+ "jay": 17,
+ "jean, blue jean, denim": 608,
+ "jeep, landrover": 609,
+ "jellyfish": 107,
+ "jersey, T-shirt, tee shirt": 610,
+ "jigsaw puzzle": 611,
+ "jinrikisha, ricksha, rickshaw": 612,
+ "joystick": 613,
+ "junco, snowbird": 13,
+ "keeshond": 261,
+ "kelpie": 227,
+ "killer whale, killer, orca, grampus, sea wolf, Orcinus orca": 148,
+ "kimono": 614,
+ "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica": 121,
+ "king penguin, Aptenodytes patagonica": 145,
+ "king snake, kingsnake": 56,
+ "kit fox, Vulpes macrotis": 278,
+ "kite": 21,
+ "knee pad": 615,
+ "knot": 616,
+ "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus": 105,
+ "komondor": 228,
+ "kuvasz": 222,
+ "lab coat, laboratory coat": 617,
+ "lacewing, lacewing fly": 318,
+ "ladle": 618,
+ "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle": 301,
+ "lakeside, lakeshore": 975,
+ "lampshade, lamp shade": 619,
+ "langur": 374,
+ "laptop, laptop computer": 620,
+ "lawn mower, mower": 621,
+ "leaf beetle, chrysomelid": 304,
+ "leafhopper": 317,
+ "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea": 34,
+ "lemon": 951,
+ "lens cap, lens cover": 622,
+ "leopard, Panthera pardus": 288,
+ "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens": 387,
+ "letter opener, paper knife, paperknife": 623,
+ "library": 624,
+ "lifeboat": 625,
+ "lighter, light, igniter, ignitor": 626,
+ "limousine, limo": 627,
+ "limpkin, Aramus pictus": 135,
+ "liner, ocean liner": 628,
+ "lion, king of beasts, Panthera leo": 291,
+ "lionfish": 396,
+ "lipstick, lip rouge": 629,
+ "little blue heron, Egretta caerulea": 131,
+ "llama": 355,
+ "loggerhead, loggerhead turtle, Caretta caretta": 33,
+ "long-horned beetle, longicorn, longicorn beetle": 303,
+ "lorikeet": 90,
+ "lotion": 631,
+ "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system": 632,
+ "loupe, jeweler's loupe": 633,
+ "lumbermill, sawmill": 634,
+ "lycaenid, lycaenid butterfly": 326,
+ "lynx, catamount": 287,
+ "macaque": 373,
+ "macaw": 88,
+ "magnetic compass": 635,
+ "magpie": 18,
+ "mailbag, postbag": 636,
+ "mailbox, letter box": 637,
+ "maillot": 638,
+ "maillot, tank suit": 639,
+ "malamute, malemute, Alaskan malamute": 249,
+ "malinois": 225,
+ "manhole cover": 640,
+ "mantis, mantid": 315,
+ "maraca": 641,
+ "marimba, xylophone": 642,
+ "marmoset": 377,
+ "marmot": 336,
+ "mashed potato": 935,
+ "mask": 643,
+ "matchstick": 644,
+ "maypole": 645,
+ "maze, labyrinth": 646,
+ "measuring cup": 647,
+ "meat loaf, meatloaf": 962,
+ "medicine chest, medicine cabinet": 648,
+ "meerkat, mierkat": 299,
+ "megalith, megalithic structure": 649,
+ "menu": 922,
+ "microphone, mike": 650,
+ "microwave, microwave oven": 651,
+ "military uniform": 652,
+ "milk can": 653,
+ "miniature pinscher": 237,
+ "miniature poodle": 266,
+ "miniature schnauzer": 196,
+ "minibus": 654,
+ "miniskirt, mini": 655,
+ "minivan": 656,
+ "mink": 357,
+ "missile": 657,
+ "mitten": 658,
+ "mixing bowl": 659,
+ "mobile home, manufactured home": 660,
+ "modem": 662,
+ "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus": 323,
+ "monastery": 663,
+ "mongoose": 298,
+ "monitor": 664,
+ "moped": 665,
+ "mortar": 666,
+ "mortarboard": 667,
+ "mosque": 668,
+ "mosquito net": 669,
+ "motor scooter, scooter": 670,
+ "mountain bike, all-terrain bike, off-roader": 671,
+ "mountain tent": 672,
+ "mouse, computer mouse": 673,
+ "mousetrap": 674,
+ "moving van": 675,
+ "mud turtle": 35,
+ "mushroom": 947,
+ "muzzle": 676,
+ "nail": 677,
+ "neck brace": 678,
+ "necklace": 679,
+ "nematode, nematode worm, roundworm": 111,
+ "night snake, Hypsiglena torquata": 60,
+ "nipple": 680,
+ "notebook, notebook computer": 681,
+ "obelisk": 682,
+ "oboe, hautboy, hautbois": 683,
+ "ocarina, sweet potato": 684,
+ "odometer, hodometer, mileometer, milometer": 685,
+ "oil filter": 686,
+ "orange": 950,
+ "orangutan, orang, orangutang, Pongo pygmaeus": 365,
+ "organ, pipe organ": 687,
+ "oscilloscope, scope, cathode-ray oscilloscope, CRO": 688,
+ "ostrich, Struthio camelus": 9,
+ "otter": 360,
+ "otterhound, otter hound": 175,
+ "overskirt": 689,
+ "ox": 345,
+ "oxcart": 690,
+ "oxygen mask": 691,
+ "oystercatcher, oyster catcher": 143,
+ "packet": 692,
+ "paddle, boat paddle": 693,
+ "paddlewheel, paddle wheel": 694,
+ "padlock": 695,
+ "paintbrush": 696,
+ "pajama, pyjama, pj's, jammies": 697,
+ "palace": 698,
+ "panpipe, pandean pipe, syrinx": 699,
+ "paper towel": 700,
+ "papillon": 157,
+ "parachute, chute": 701,
+ "parallel bars, bars": 702,
+ "park bench": 703,
+ "parking meter": 704,
+ "partridge": 86,
+ "passenger car, coach, carriage": 705,
+ "patas, hussar monkey, Erythrocebus patas": 371,
+ "patio, terrace": 706,
+ "pay-phone, pay-station": 707,
+ "peacock": 84,
+ "pedestal, plinth, footstall": 708,
+ "pelican": 144,
+ "pencil box, pencil case": 709,
+ "pencil sharpener": 710,
+ "perfume, essence": 711,
+ "photocopier": 713,
+ "pick, plectrum, plectron": 714,
+ "pickelhaube": 715,
+ "picket fence, paling": 716,
+ "pickup, pickup truck": 717,
+ "pier": 718,
+ "piggy bank, penny bank": 719,
+ "pill bottle": 720,
+ "pillow": 721,
+ "pineapple, ananas": 953,
+ "ping-pong ball": 722,
+ "pinwheel": 723,
+ "pirate, pirate ship": 724,
+ "pitcher, ewer": 725,
+ "pizza, pizza pie": 963,
+ "plane, carpenter's plane, woodworking plane": 726,
+ "planetarium": 727,
+ "plastic bag": 728,
+ "plate": 923,
+ "plate rack": 729,
+ "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus": 103,
+ "plow, plough": 730,
+ "plunger, plumber's helper": 731,
+ "pole": 733,
+ "polecat, fitch, foulmart, foumart, Mustela putorius": 358,
+ "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria": 734,
+ "pomegranate": 957,
+ "poncho": 735,
+ "pool table, billiard table, snooker table": 736,
+ "pop bottle, soda bottle": 737,
+ "porcupine, hedgehog": 334,
+ "pot, flowerpot": 738,
+ "potpie": 964,
+ "potter's wheel": 739,
+ "power drill": 740,
+ "prairie chicken, prairie grouse, prairie fowl": 83,
+ "prayer rug, prayer mat": 741,
+ "pretzel": 932,
+ "printer": 742,
+ "prison, prison house": 743,
+ "proboscis monkey, Nasalis larvatus": 376,
+ "projectile, missile": 744,
+ "projector": 745,
+ "promontory, headland, head, foreland": 976,
+ "ptarmigan": 81,
+ "puck, hockey puck": 746,
+ "puffer, pufferfish, blowfish, globefish": 397,
+ "pug, pug-dog": 254,
+ "punching bag, punch bag, punching ball, punchball": 747,
+ "purse": 748,
+ "quail": 85,
+ "quill, quill pen": 749,
+ "quilt, comforter, comfort, puff": 750,
+ "racer, race car, racing car": 751,
+ "racket, racquet": 752,
+ "radiator": 753,
+ "radio telescope, radio reflector": 755,
+ "radio, wireless": 754,
+ "rain barrel": 756,
+ "ram, tup": 348,
+ "rapeseed": 984,
+ "recreational vehicle, RV, R.V.": 757,
+ "red fox, Vulpes vulpes": 277,
+ "red wine": 966,
+ "red wolf, maned wolf, Canis rufus, Canis niger": 271,
+ "red-backed sandpiper, dunlin, Erolia alpina": 140,
+ "red-breasted merganser, Mergus serrator": 98,
+ "redbone": 168,
+ "redshank, Tringa totanus": 141,
+ "reel": 758,
+ "reflex camera": 759,
+ "refrigerator, icebox": 760,
+ "remote control, remote": 761,
+ "restaurant, eating house, eating place, eatery": 762,
+ "revolver, six-gun, six-shooter": 763,
+ "rhinoceros beetle": 306,
+ "rifle": 764,
+ "ringlet, ringlet butterfly": 322,
+ "ringneck snake, ring-necked snake, ring snake": 53,
+ "robin, American robin, Turdus migratorius": 15,
+ "rock beauty, Holocanthus tricolor": 392,
+ "rock crab, Cancer irroratus": 119,
+ "rock python, rock snake, Python sebae": 62,
+ "rocking chair, rocker": 765,
+ "rotisserie": 766,
+ "rubber eraser, rubber, pencil eraser": 767,
+ "ruddy turnstone, Arenaria interpres": 139,
+ "ruffed grouse, partridge, Bonasa umbellus": 82,
+ "rugby ball": 768,
+ "rule, ruler": 769,
+ "running shoe": 770,
+ "safe": 771,
+ "safety pin": 772,
+ "saltshaker, salt shaker": 773,
+ "sandal": 774,
+ "sandbar, sand bar": 977,
+ "sarong": 775,
+ "sax, saxophone": 776,
+ "scabbard": 777,
+ "scale, weighing machine": 778,
+ "schipperke": 223,
+ "school bus": 779,
+ "schooner": 780,
+ "scoreboard": 781,
+ "scorpion": 71,
+ "screen, CRT screen": 782,
+ "screw": 783,
+ "screwdriver": 784,
+ "scuba diver": 983,
+ "sea anemone, anemone": 108,
+ "sea cucumber, holothurian": 329,
+ "sea lion": 150,
+ "sea slug, nudibranch": 115,
+ "sea snake": 65,
+ "sea urchin": 328,
+ "seashore, coast, seacoast, sea-coast": 978,
+ "seat belt, seatbelt": 785,
+ "sewing machine": 786,
+ "shield, buckler": 787,
+ "shoe shop, shoe-shop, shoe store": 788,
+ "shoji": 789,
+ "shopping basket": 790,
+ "shopping cart": 791,
+ "shovel": 792,
+ "shower cap": 793,
+ "shower curtain": 794,
+ "siamang, Hylobates syndactylus, Symphalangus syndactylus": 369,
+ "sidewinder, horned rattlesnake, Crotalus cerastes": 68,
+ "silky terrier, Sydney silky": 201,
+ "ski": 795,
+ "ski mask": 796,
+ "skunk, polecat, wood pussy": 361,
+ "sleeping bag": 797,
+ "slide rule, slipstick": 798,
+ "sliding door": 799,
+ "slot, one-armed bandit": 800,
+ "sloth bear, Melursus ursinus, Ursus ursinus": 297,
+ "slug": 114,
+ "snail": 113,
+ "snorkel": 801,
+ "snow leopard, ounce, Panthera uncia": 289,
+ "snowmobile": 802,
+ "snowplow, snowplough": 803,
+ "soap dispenser": 804,
+ "soccer ball": 805,
+ "sock": 806,
+ "soft-coated wheaten terrier": 202,
+ "solar dish, solar collector, solar furnace": 807,
+ "sombrero": 808,
+ "sorrel": 339,
+ "soup bowl": 809,
+ "space bar": 810,
+ "space heater": 811,
+ "space shuttle": 812,
+ "spaghetti squash": 940,
+ "spatula": 813,
+ "speedboat": 814,
+ "spider monkey, Ateles geoffroyi": 381,
+ "spider web, spider's web": 815,
+ "spindle": 816,
+ "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish": 123,
+ "spoonbill": 129,
+ "sports car, sport car": 817,
+ "spotlight, spot": 818,
+ "spotted salamander, Ambystoma maculatum": 28,
+ "squirrel monkey, Saimiri sciureus": 382,
+ "stage": 819,
+ "standard poodle": 267,
+ "standard schnauzer": 198,
+ "starfish, sea star": 327,
+ "steam locomotive": 820,
+ "steel arch bridge": 821,
+ "steel drum": 822,
+ "stethoscope": 823,
+ "stingray": 6,
+ "stinkhorn, carrion fungus": 994,
+ "stole": 824,
+ "stone wall": 825,
+ "stopwatch, stop watch": 826,
+ "stove": 827,
+ "strainer": 828,
+ "strawberry": 949,
+ "street sign": 919,
+ "streetcar, tram, tramcar, trolley, trolley car": 829,
+ "stretcher": 830,
+ "studio couch, day bed": 831,
+ "stupa, tope": 832,
+ "sturgeon": 394,
+ "submarine, pigboat, sub, U-boat": 833,
+ "suit, suit of clothes": 834,
+ "sulphur butterfly, sulfur butterfly": 325,
+ "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita": 89,
+ "sundial": 835,
+ "sunglass": 836,
+ "sunglasses, dark glasses, shades": 837,
+ "sunscreen, sunblock, sun blocker": 838,
+ "suspension bridge": 839,
+ "swab, swob, mop": 840,
+ "sweatshirt": 841,
+ "swimming trunks, bathing trunks": 842,
+ "swing": 843,
+ "switch, electric switch, electrical switch": 844,
+ "syringe": 845,
+ "tabby, tabby cat": 281,
+ "table lamp": 846,
+ "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui": 32,
+ "tank, army tank, armored combat vehicle, armoured combat vehicle": 847,
+ "tape player": 848,
+ "tarantula": 76,
+ "teapot": 849,
+ "teddy, teddy bear": 850,
+ "television, television system": 851,
+ "tench, Tinca tinca": 0,
+ "tennis ball": 852,
+ "terrapin": 36,
+ "thatch, thatched roof": 853,
+ "theater curtain, theatre curtain": 854,
+ "thimble": 855,
+ "three-toed sloth, ai, Bradypus tridactylus": 364,
+ "thresher, thrasher, threshing machine": 856,
+ "throne": 857,
+ "thunder snake, worm snake, Carphophis amoenus": 52,
+ "tick": 78,
+ "tiger beetle": 300,
+ "tiger cat": 282,
+ "tiger shark, Galeocerdo cuvieri": 3,
+ "tiger, Panthera tigris": 292,
+ "tile roof": 858,
+ "timber wolf, grey wolf, gray wolf, Canis lupus": 269,
+ "titi, titi monkey": 380,
+ "toaster": 859,
+ "tobacco shop, tobacconist shop, tobacconist": 860,
+ "toilet seat": 861,
+ "toilet tissue, toilet paper, bathroom tissue": 999,
+ "torch": 862,
+ "totem pole": 863,
+ "toucan": 96,
+ "tow truck, tow car, wrecker": 864,
+ "toy poodle": 265,
+ "toy terrier": 158,
+ "toyshop": 865,
+ "tractor": 866,
+ "traffic light, traffic signal, stoplight": 920,
+ "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi": 867,
+ "tray": 868,
+ "tree frog, tree-frog": 31,
+ "trench coat": 869,
+ "triceratops": 51,
+ "tricycle, trike, velocipede": 870,
+ "trifle": 927,
+ "trilobite": 69,
+ "trimaran": 871,
+ "tripod": 872,
+ "triumphal arch": 873,
+ "trolleybus, trolley coach, trackless trolley": 874,
+ "trombone": 875,
+ "tub, vat": 876,
+ "turnstile": 877,
+ "tusker": 101,
+ "typewriter keyboard": 878,
+ "umbrella": 879,
+ "unicycle, monocycle": 880,
+ "upright, upright piano": 881,
+ "vacuum, vacuum cleaner": 882,
+ "valley, vale": 979,
+ "vase": 883,
+ "vault": 884,
+ "velvet": 885,
+ "vending machine": 886,
+ "vestment": 887,
+ "viaduct": 888,
+ "vine snake": 59,
+ "violin, fiddle": 889,
+ "vizsla, Hungarian pointer": 211,
+ "volcano": 980,
+ "volleyball": 890,
+ "vulture": 23,
+ "waffle iron": 891,
+ "walking stick, walkingstick, stick insect": 313,
+ "wall clock": 892,
+ "wallaby, brush kangaroo": 104,
+ "wallet, billfold, notecase, pocketbook": 893,
+ "wardrobe, closet, press": 894,
+ "warplane, military plane": 895,
+ "warthog": 343,
+ "washbasin, handbasin, washbowl, lavabo, wash-hand basin": 896,
+ "washer, automatic washer, washing machine": 897,
+ "water bottle": 898,
+ "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis": 346,
+ "water jug": 899,
+ "water ouzel, dipper": 20,
+ "water snake": 58,
+ "water tower": 900,
+ "weasel": 356,
+ "web site, website, internet site, site": 916,
+ "weevil": 307,
+ "whippet": 172,
+ "whiptail, whiptail lizard": 41,
+ "whiskey jug": 901,
+ "whistle": 902,
+ "white stork, Ciconia ciconia": 127,
+ "white wolf, Arctic wolf, Canis lupus tundrarum": 270,
+ "wig": 903,
+ "wild boar, boar, Sus scrofa": 342,
+ "window screen": 904,
+ "window shade": 905,
+ "wine bottle": 907,
+ "wing": 908,
+ "wire-haired fox terrier": 188,
+ "wok": 909,
+ "wolf spider, hunting spider": 77,
+ "wombat": 106,
+ "wood rabbit, cottontail, cottontail rabbit": 330,
+ "wooden spoon": 910,
+ "wool, woolen, woollen": 911,
+ "worm fence, snake fence, snake-rail fence, Virginia fence": 912,
+ "wreck": 913,
+ "yawl": 914,
+ "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum": 986,
+ "yurt": 915,
+ "zebra": 340,
+ "zucchini, courgette": 939
+ },
+ "layer_norm_eps": 1e-05,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "min_length": 0,
+ "mlp_ratio": 4.0,
+ "model_type": "swin",
+ "no_repeat_ngram_size": 0,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_channels": 3,
+ "num_heads": [
+ 3,
+ 6,
+ 12,
+ 24
+ ],
+ "num_layers": 4,
+ "num_return_sequences": 1,
+ "out_features": [
+ "stage1",
+ "stage2",
+ "stage3",
+ "stage4"
+ ],
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": null,
+ "patch_size": 4,
+ "path_norm": true,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "qkv_bias": true,
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sep_token_id": null,
+ "stage_names": [
+ "stem",
+ "stage1",
+ "stage2",
+ "stage3",
+ "stage4"
+ ],
+ "suppress_tokens": null,
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": "float32",
+ "torchscript": false,
+ "transformers_version": "4.26.0.dev0",
+ "typical_p": 1.0,
+ "use_absolute_embeddings": false,
+ "use_bfloat16": false,
+ "window_size": 7
+ },
+ "class_weight": 2.0,
+ "common_stride": 4,
+ "decoder_layers": 10,
+ "dice_weight": 5.0,
+ "dim_feedforward": 2048,
+ "dropout": 0.0,
+ "encoder_feedforward_dim": 1024,
+ "encoder_layers": 6,
+ "enforce_input_proj": false,
+ "enforce_input_projection": false,
+ "feature_size": 256,
+ "feature_strides": [
+ 4,
+ 8,
+ 16,
+ 32
+ ],
+ "hidden_dim": 256,
+ "id2label": {
+ "0": "person",
+ "1": "bicycle",
+ "2": "car",
+ "3": "motorbike",
+ "4": "aeroplane",
+ "5": "bus",
+ "6": "train",
+ "7": "truck",
+ "8": "boat",
+ "9": "traffic light",
+ "10": "fire hydrant",
+ "11": "stop sign",
+ "12": "parking meter",
+ "13": "bench",
+ "14": "bird",
+ "15": "cat",
+ "16": "dog",
+ "17": "horse",
+ "18": "sheep",
+ "19": "cow",
+ "20": "elephant",
+ "21": "bear",
+ "22": "zebra",
+ "23": "giraffe",
+ "24": "backpack",
+ "25": "umbrella",
+ "26": "handbag",
+ "27": "tie",
+ "28": "suitcase",
+ "29": "frisbee",
+ "30": "skis",
+ "31": "snowboard",
+ "32": "sports ball",
+ "33": "kite",
+ "34": "baseball bat",
+ "35": "baseball glove",
+ "36": "skateboard",
+ "37": "surfboard",
+ "38": "tennis racket",
+ "39": "bottle",
+ "40": "wine glass",
+ "41": "cup",
+ "42": "fork",
+ "43": "knife",
+ "44": "spoon",
+ "45": "bowl",
+ "46": "banana",
+ "47": "apple",
+ "48": "sandwich",
+ "49": "orange",
+ "50": "broccoli",
+ "51": "carrot",
+ "52": "hot dog",
+ "53": "pizza",
+ "54": "donut",
+ "55": "cake",
+ "56": "chair",
+ "57": "sofa",
+ "58": "pottedplant",
+ "59": "bed",
+ "60": "diningtable",
+ "61": "toilet",
+ "62": "tvmonitor",
+ "63": "laptop",
+ "64": "mouse",
+ "65": "remote",
+ "66": "keyboard",
+ "67": "cell phone",
+ "68": "microwave",
+ "69": "oven",
+ "70": "toaster",
+ "71": "sink",
+ "72": "refrigerator",
+ "73": "book",
+ "74": "clock",
+ "75": "vase",
+ "76": "scissors",
+ "77": "teddy bear",
+ "78": "hair drier",
+ "79": "toothbrush"
+ },
+ "ignore_value": 255,
+ "importance_sample_ratio": 0.75,
+ "init_std": 0.02,
+ "init_xavier_std": 1.0,
+ "label2id": {
+ "aeroplane": 4,
+ "apple": 47,
+ "backpack": 24,
+ "banana": 46,
+ "baseball bat": 34,
+ "baseball glove": 35,
+ "bear": 21,
+ "bed": 59,
+ "bench": 13,
+ "bicycle": 1,
+ "bird": 14,
+ "boat": 8,
+ "book": 73,
+ "bottle": 39,
+ "bowl": 45,
+ "broccoli": 50,
+ "bus": 5,
+ "cake": 55,
+ "car": 2,
+ "carrot": 51,
+ "cat": 15,
+ "cell phone": 67,
+ "chair": 56,
+ "clock": 74,
+ "cow": 19,
+ "cup": 41,
+ "diningtable": 60,
+ "dog": 16,
+ "donut": 54,
+ "elephant": 20,
+ "fire hydrant": 10,
+ "fork": 42,
+ "frisbee": 29,
+ "giraffe": 23,
+ "hair drier": 78,
+ "handbag": 26,
+ "horse": 17,
+ "hot dog": 52,
+ "keyboard": 66,
+ "kite": 33,
+ "knife": 43,
+ "laptop": 63,
+ "microwave": 68,
+ "motorbike": 3,
+ "mouse": 64,
+ "orange": 49,
+ "oven": 69,
+ "parking meter": 12,
+ "person": 0,
+ "pizza": 53,
+ "pottedplant": 58,
+ "refrigerator": 72,
+ "remote": 65,
+ "sandwich": 48,
+ "scissors": 76,
+ "sheep": 18,
+ "sink": 71,
+ "skateboard": 36,
+ "skis": 30,
+ "snowboard": 31,
+ "sofa": 57,
+ "spoon": 44,
+ "sports ball": 32,
+ "stop sign": 11,
+ "suitcase": 28,
+ "surfboard": 37,
+ "teddy bear": 77,
+ "tennis racket": 38,
+ "tie": 27,
+ "toaster": 70,
+ "toilet": 61,
+ "toothbrush": 79,
+ "traffic light": 9,
+ "train": 6,
+ "truck": 7,
+ "tvmonitor": 62,
+ "umbrella": 25,
+ "vase": 75,
+ "wine glass": 40,
+ "zebra": 22
+ },
+ "mask_feature_size": 256,
+ "mask_weight": 5.0,
+ "model_type": "mask2former",
+ "no_object_weight": 0.1,
+ "num_attention_heads": 8,
+ "num_hidden_layers": 10,
+ "num_queries": 100,
+ "output_auxiliary_logits": null,
+ "oversample_ratio": 3.0,
+ "pre_norm": false,
+ "torch_dtype": "float32",
+ "train_num_points": 12544,
+ "transformers_version": null,
+ "use_auxiliary_loss": true
+}
diff --git a/external/human_matting/__init__.py b/external/human_matting/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56f64d6494cbd0375685400c6f35afcb5cd3c2b4
--- /dev/null
+++ b/external/human_matting/__init__.py
@@ -0,0 +1 @@
+from .matting_engine import StyleMatteEngine
diff --git a/external/human_matting/matting_engine.py b/external/human_matting/matting_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a833f550ccd582f5bcaac712c4aa1f428835a7a
--- /dev/null
+++ b/external/human_matting/matting_engine.py
@@ -0,0 +1,66 @@
+import os
+import torch
+import inspect
+import warnings
+import torchvision
+from .stylematte import StyleMatte
+
+class StyleMatteEngine(torch.nn.Module):
+ def __init__(self, device='cpu',human_matting_path='./pretrain_model/matting/stylematte_synth.pt'):
+ super().__init__()
+ self._device = device
+ self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ self._init_models(human_matting_path)
+
+ def _init_models(self,_ckpt_path):
+ # load dict
+ state_dict = torch.load(_ckpt_path, map_location='cpu')
+ # build model
+ model = StyleMatte()
+ model.load_state_dict(state_dict)
+ self.model = model.to(self._device).eval()
+
+ @torch.no_grad()
+ def forward(self, input_image, return_type='matting', background_rgb=1.0):
+ if not hasattr(self, 'model'):
+ self._init_models()
+ if input_image.max() > 2.0:
+ warnings.warn('Image should be normalized to [0, 1].')
+ _, ori_h, ori_w = input_image.shape
+ input_image = input_image.to(self._device).float()
+ image = input_image.clone()
+ # resize
+ if max(ori_h, ori_w) > 1024:
+ scale = 1024.0 / max(ori_h, ori_w)
+ resized_h, resized_w = int(ori_h * scale), int(ori_w * scale)
+ image = torchvision.transforms.functional.resize(image, (resized_h, resized_w), antialias=True)
+ else:
+ resized_h, resized_w = ori_h, ori_w
+ # padding
+ if resized_h % 8 != 0 or resized_w % 8 != 0:
+ image = torchvision.transforms.functional.pad(image, ((8-resized_w % 8)%8, (8-resized_h % 8)%8, 0, 0, ), padding_mode='reflect')
+ # normalize and forwarding
+ image = self.normalize(image)[None]
+ predict = self.model(image)[0]
+ # undo padding
+ predict = predict[:, -resized_h:, -resized_w:]
+ # undo resize
+ if resized_h != ori_h or resized_w != ori_w:
+ predict = torchvision.transforms.functional.resize(predict, (ori_h, ori_w), antialias=True)
+
+ if return_type == 'alpha':
+ return predict[0]
+ elif return_type == 'matting':
+ predict = predict.expand(3, -1, -1)
+ matting_image = input_image.clone()
+ background_rgb = matting_image.new_ones(matting_image.shape) * background_rgb
+ matting_image = matting_image * predict + (1-predict) * background_rgb
+ return matting_image, predict[0]
+ elif return_type == 'all':
+ predict = predict.expand(3, -1, -1)
+ background_rgb = input_image.new_ones(input_image.shape) * background_rgb
+ foreground_image = input_image * predict + (1-predict) * background_rgb
+ background_image = input_image * (1-predict) + predict * background_rgb
+ return foreground_image, background_image
+ else:
+ raise NotImplementedError
diff --git a/external/human_matting/stylematte.py b/external/human_matting/stylematte.py
new file mode 100644
index 0000000000000000000000000000000000000000..db11ec79e69e7c95ca66fb2e4ae85235965cb6be
--- /dev/null
+++ b/external/human_matting/stylematte.py
@@ -0,0 +1,272 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers import Mask2FormerForUniversalSegmentation
+from transformers.models.mask2former.configuration_mask2former import Mask2FormerConfig
+
+class StyleMatte(nn.Module):
+ def __init__(self):
+ super(StyleMatte, self).__init__()
+ self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256], fpn_out=256)
+ config = Mask2FormerConfig.from_json_file('./configs/stylematte_config.json')
+ self.pixel_decoder = Mask2FormerForUniversalSegmentation(config).base_model.pixel_level_module
+ self.fgf = FastGuidedFilter(eps=1e-4)
+ self.conv = nn.Conv2d(256, 1, kernel_size=3, padding=1)
+
+ def forward(self, image, normalize=False):
+ decoder_out = self.pixel_decoder(image)
+ decoder_states = list(decoder_out.decoder_hidden_states)
+ decoder_states.append(decoder_out.decoder_last_hidden_state)
+ out_pure = self.fpn(decoder_states)
+
+ image_lr = nn.functional.interpolate(image.mean(1, keepdim=True),
+ scale_factor=0.25,
+ mode='bicubic',
+ align_corners=True
+ )
+ out = self.conv(out_pure)
+ out = self.fgf(image_lr, out, image.mean(1, keepdim=True))
+
+ return torch.sigmoid(out)
+
+ def get_training_params(self):
+ return list(self.fpn.parameters())+list(self.conv.parameters())
+
+
+def conv2d_relu(input_filters, output_filters, kernel_size=3, bias=True):
+ return nn.Sequential(
+ nn.Conv2d(input_filters, output_filters,
+ kernel_size=kernel_size, padding=kernel_size//2, bias=bias),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.BatchNorm2d(output_filters)
+ )
+
+
+def up_and_add(x, y):
+ return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y
+
+
+class FPN_fuse(nn.Module):
+ def __init__(self, feature_channels=[256, 512, 1024, 2048], fpn_out=256):
+ super(FPN_fuse, self).__init__()
+ assert feature_channels[0] == fpn_out
+ self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1)
+ for ft_size in feature_channels[1:]])
+ self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)]
+ * (len(feature_channels)-1))
+ self.conv_fusion = nn.Sequential(
+ nn.Conv2d(2*fpn_out, fpn_out, kernel_size=3,
+ padding=1, bias=False),
+ nn.BatchNorm2d(fpn_out),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, features):
+
+ features[:-1] = [conv1x1(feature) for feature,
+ conv1x1 in zip(features[:-1], self.conv1x1)]
+ feature = up_and_add(self.smooth_conv[0](features[0]), features[1])
+ feature = up_and_add(self.smooth_conv[1](feature), features[2])
+ feature = up_and_add(self.smooth_conv[2](feature), features[3])
+
+ H, W = features[-1].size(2), features[-1].size(3)
+ x = [feature, features[-1]]
+ x = [F.interpolate(x_el, size=(H, W), mode='bilinear',
+ align_corners=True) for x_el in x]
+
+ x = self.conv_fusion(torch.cat(x, dim=1))
+
+ return x
+
+
+class PSPModule(nn.Module):
+ # In the original inmplementation they use precise RoI pooling
+ # Instead of using adaptative average pooling
+ def __init__(self, in_channels, bin_sizes=[1, 2, 4, 6]):
+ super(PSPModule, self).__init__()
+ out_channels = in_channels // len(bin_sizes)
+ self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
+ for b_s in bin_sizes])
+ self.bottleneck = nn.Sequential(
+ nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), in_channels,
+ kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(in_channels),
+ nn.ReLU(inplace=True),
+ nn.Dropout2d(0.1)
+ )
+
+ def _make_stages(self, in_channels, out_channels, bin_sz):
+ prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
+ conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+ bn = nn.BatchNorm2d(out_channels)
+ relu = nn.ReLU(inplace=True)
+ return nn.Sequential(prior, conv, bn, relu)
+
+ def forward(self, features):
+ h, w = features.size()[2], features.size()[3]
+ pyramids = [features]
+ pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
+ align_corners=True) for stage in self.stages])
+ output = self.bottleneck(torch.cat(pyramids, dim=1))
+ return output
+
+
+class GuidedFilter(nn.Module):
+ def __init__(self, r, eps=1e-8):
+ super(GuidedFilter, self).__init__()
+
+ self.r = r
+ self.eps = eps
+ self.boxfilter = BoxFilter(r)
+
+ def forward(self, x, y):
+ n_x, c_x, h_x, w_x = x.size()
+ n_y, c_y, h_y, w_y = y.size()
+
+ assert n_x == n_y
+ assert c_x == 1 or c_x == c_y
+ assert h_x == h_y and w_x == w_y
+ assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1
+
+ # N
+ N = self.boxfilter((x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))
+
+ # mean_x
+ mean_x = self.boxfilter(x) / N
+ # mean_y
+ mean_y = self.boxfilter(y) / N
+ # cov_xy
+ cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
+ # var_x
+ var_x = self.boxfilter(x * x) / N - mean_x * mean_x
+
+ # A
+ A = cov_xy / (var_x + self.eps)
+ # b
+ b = mean_y - A * mean_x
+
+ # mean_A; mean_b
+ mean_A = self.boxfilter(A) / N
+ mean_b = self.boxfilter(b) / N
+
+ return mean_A * x + mean_b
+
+
+class FastGuidedFilter(nn.Module):
+ def __init__(self, r=1, eps=1e-8):
+ super(FastGuidedFilter, self).__init__()
+
+ self.r = r
+ self.eps = eps
+ self.boxfilter = BoxFilter(r)
+
+ def forward(self, lr_x, lr_y, hr_x):
+ n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
+ n_lry, c_lry, h_lry, w_lry = lr_y.size()
+ n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()
+
+ assert n_lrx == n_lry and n_lry == n_hrx
+ assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
+ assert h_lrx == h_lry and w_lrx == w_lry
+ assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1
+
+ # N
+ N = self.boxfilter(lr_x.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0))
+
+ # mean_x
+ mean_x = self.boxfilter(lr_x) / N
+ # mean_y
+ mean_y = self.boxfilter(lr_y) / N
+ # cov_xy
+ cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
+ # var_x
+ var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x
+
+ # A
+ A = cov_xy / (var_x + self.eps)
+ # b
+ b = mean_y - A * mean_x
+
+ # mean_A; mean_b
+ mean_A = F.interpolate(
+ A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
+ mean_b = F.interpolate(
+ b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
+
+ return mean_A*hr_x+mean_b
+
+
+class DeepGuidedFilterRefiner(nn.Module):
+ def __init__(self, hid_channels=16):
+ super().__init__()
+ self.box_filter = nn.Conv2d(
+ 4, 4, kernel_size=3, padding=1, bias=False, groups=4)
+ self.box_filter.weight.data[...] = 1 / 9
+ self.conv = nn.Sequential(
+ nn.Conv2d(4 * 2 + hid_channels, hid_channels,
+ kernel_size=1, bias=False),
+ nn.BatchNorm2d(hid_channels),
+ nn.ReLU(True),
+ nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
+ nn.BatchNorm2d(hid_channels),
+ nn.ReLU(True),
+ nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
+ )
+
+ def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
+ fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
+ base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
+ base_y = torch.cat([base_fgr, base_pha], dim=1)
+
+ mean_x = self.box_filter(base_x)
+ mean_y = self.box_filter(base_y)
+ cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
+ var_x = self.box_filter(base_x * base_x) - mean_x * mean_x
+
+ A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
+ b = mean_y - A * mean_x
+
+ H, W = fine_src.shape[2:]
+ A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
+ b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
+
+ out = A * fine_x + b
+ fgr, pha = out.split([3, 1], dim=1)
+ return fgr, pha
+
+
+def diff_x(input, r):
+ assert input.dim() == 4
+
+ left = input[:, :, r:2 * r + 1]
+ middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
+ right = input[:, :, -1:] - input[:, :, -2 * r - 1: -r - 1]
+
+ output = torch.cat([left, middle, right], dim=2)
+
+ return output
+
+
+def diff_y(input, r):
+ assert input.dim() == 4
+
+ left = input[:, :, :, r:2 * r + 1]
+ middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
+ right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1: -r - 1]
+
+ output = torch.cat([left, middle, right], dim=3)
+
+ return output
+
+
+class BoxFilter(nn.Module):
+ def __init__(self, r):
+ super(BoxFilter, self).__init__()
+
+ self.r = r
+
+ def forward(self, x):
+ assert x.dim() == 4
+
+ return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
diff --git a/external/landmark_detection/FaceBoxesV2/__init__.py b/external/landmark_detection/FaceBoxesV2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..336a4de8633fbb890fa88973efd234778fb5dee5
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/__init__.py
@@ -0,0 +1,2 @@
+from . import detector
+from . import faceboxes_detector
\ No newline at end of file
diff --git a/external/landmark_detection/FaceBoxesV2/detector.py b/external/landmark_detection/FaceBoxesV2/detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb9c8fe988e05bb5d72103d6699ca8eac6de678f
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/detector.py
@@ -0,0 +1,39 @@
+import cv2
+
+class Detector(object):
+ def __init__(self, model_arch, model_weights):
+ self.model_arch = model_arch
+ self.model_weights = model_weights
+
+ def detect(self, image, thresh):
+ raise NotImplementedError
+
+ def crop(self, image, detections):
+ crops = []
+ for det in detections:
+ xmin = max(det[2], 0)
+ ymin = max(det[3], 0)
+ width = det[4]
+ height = det[5]
+ xmax = min(xmin+width, image.shape[1])
+ ymax = min(ymin+height, image.shape[0])
+ cut = image[ymin:ymax, xmin:xmax,:]
+ crops.append(cut)
+
+ return crops
+
+ def draw(self, image, detections, im_scale=None):
+ if im_scale is not None:
+ image = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
+ detections = [[det[0],det[1],int(det[2]*im_scale),int(det[3]*im_scale),int(det[4]*im_scale),int(det[5]*im_scale)] for det in detections]
+
+ for det in detections:
+ xmin = det[2]
+ ymin = det[3]
+ width = det[4]
+ height = det[5]
+ xmax = xmin + width
+ ymax = ymin + height
+ cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
+
+ return image
diff --git a/external/landmark_detection/FaceBoxesV2/faceboxes_detector.py b/external/landmark_detection/FaceBoxesV2/faceboxes_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..04c9e8fe069494c4d06072e34ff17168e9e493a7
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/faceboxes_detector.py
@@ -0,0 +1,97 @@
+from .detector import Detector
+import cv2, os
+import numpy as np
+import torch
+import torch.nn as nn
+from .utils.config import cfg
+from .utils.prior_box import PriorBox
+from .utils.nms_wrapper import nms
+from .utils.faceboxes import FaceBoxesV2
+from .utils.box_utils import decode
+import time
+
+class FaceBoxesDetector(Detector):
+ def __init__(self, model_arch, model_weights, use_gpu, device):
+ super().__init__(model_arch, model_weights)
+ self.name = 'FaceBoxesDetector'
+ self.net = FaceBoxesV2(phase='test', size=None, num_classes=2) # initialize detector
+ self.use_gpu = use_gpu
+ self.device = device
+
+ state_dict = torch.load(self.model_weights, map_location=self.device)
+ # create new OrderedDict that does not contain `module.`
+ from collections import OrderedDict
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ name = k[7:] # remove `module.`
+ new_state_dict[name] = v
+ # load params
+ self.net.load_state_dict(new_state_dict)
+ self.net = self.net.to(self.device)
+ self.net.eval()
+
+
+ def detect(self, image, thresh=0.6, im_scale=None):
+ # auto resize for large images
+ if im_scale is None:
+ height, width, _ = image.shape
+ if min(height, width) > 600:
+ im_scale = 600. / min(height, width)
+ else:
+ im_scale = 1
+ image_scale = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
+
+ scale = torch.Tensor([image_scale.shape[1], image_scale.shape[0], image_scale.shape[1], image_scale.shape[0]])
+ image_scale = torch.from_numpy(image_scale.transpose(2,0,1)).to(self.device).int()
+ mean_tmp = torch.IntTensor([104, 117, 123]).to(self.device)
+ mean_tmp = mean_tmp.unsqueeze(1).unsqueeze(2)
+ image_scale -= mean_tmp
+ image_scale = image_scale.float().unsqueeze(0)
+ scale = scale.to(self.device)
+
+ with torch.no_grad():
+ out = self.net(image_scale)
+ #priorbox = PriorBox(cfg, out[2], (image_scale.size()[2], image_scale.size()[3]), phase='test')
+ priorbox = PriorBox(cfg, image_size=(image_scale.size()[2], image_scale.size()[3]))
+ priors = priorbox.forward()
+ priors = priors.to(self.device)
+ loc, conf = out
+ prior_data = priors.data
+ boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
+ boxes = boxes * scale
+ boxes = boxes.cpu().numpy()
+ scores = conf.data.cpu().numpy()[:, 1]
+
+ # ignore low scores
+ inds = np.where(scores > thresh)[0]
+ boxes = boxes[inds]
+ scores = scores[inds]
+
+ # keep top-K before NMS
+ order = scores.argsort()[::-1][:5000]
+ boxes = boxes[order]
+ scores = scores[order]
+
+ # do NMS
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+ keep = nms(dets, 0.3)
+ dets = dets[keep, :]
+
+ dets = dets[:750, :]
+ detections_scale = []
+ for i in range(dets.shape[0]):
+ xmin = int(dets[i][0])
+ ymin = int(dets[i][1])
+ xmax = int(dets[i][2])
+ ymax = int(dets[i][3])
+ score = dets[i][4]
+ width = xmax - xmin
+ height = ymax - ymin
+ detections_scale.append(['face', score, xmin, ymin, width, height])
+
+ # adapt bboxes to the original image size
+ if len(detections_scale) > 0:
+ detections_scale = [[det[0],det[1],int(det[2]/im_scale),int(det[3]/im_scale),int(det[4]/im_scale),int(det[5]/im_scale)] for det in detections_scale]
+
+ return detections_scale, im_scale
+
diff --git a/external/landmark_detection/FaceBoxesV2/utils/__init__.py b/external/landmark_detection/FaceBoxesV2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/external/landmark_detection/FaceBoxesV2/utils/box_utils.py b/external/landmark_detection/FaceBoxesV2/utils/box_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4797f1d7498cc35499c9b86a35c0754eb16e5a60
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/box_utils.py
@@ -0,0 +1,276 @@
+import torch
+import numpy as np
+
+
+def point_form(boxes):
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
+ representation for comparison to point form ground truth data.
+ Args:
+ boxes: (tensor) center-size default boxes from priorbox layers.
+ Return:
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+ """
+ return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin
+ boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax
+
+
+def center_size(boxes):
+ """ Convert prior_boxes to (cx, cy, w, h)
+ representation for comparison to center-size form ground truth data.
+ Args:
+ boxes: (tensor) point_form boxes
+ Return:
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+ """
+ return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy
+ boxes[:, 2:] - boxes[:, :2], 1) # w, h
+
+
+def intersect(box_a, box_b):
+ """ We resize both tensors to [A,B,2] without new malloc:
+ [A,2] -> [A,1,2] -> [A,B,2]
+ [B,2] -> [1,B,2] -> [A,B,2]
+ Then we compute the area of intersect between box_a and box_b.
+ Args:
+ box_a: (tensor) bounding boxes, Shape: [A,4].
+ box_b: (tensor) bounding boxes, Shape: [B,4].
+ Return:
+ (tensor) intersection area, Shape: [A,B].
+ """
+ A = box_a.size(0)
+ B = box_b.size(0)
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
+ box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
+ box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+ inter = torch.clamp((max_xy - min_xy), min=0)
+ return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
+ is simply the intersection over union of two boxes. Here we operate on
+ ground truth boxes and default boxes.
+ E.g.:
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+ Args:
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+ Return:
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+ """
+ inter = intersect(box_a, box_b)
+ area_a = ((box_a[:, 2]-box_a[:, 0]) *
+ (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
+ area_b = ((box_b[:, 2]-box_b[:, 0]) *
+ (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
+ union = area_a + area_b - inter
+ return inter / union # [A,B]
+
+
+def matrix_iou(a, b):
+ """
+ return iou of a and b, numpy version for data augenmentation
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+def matrix_iof(a, b):
+ """
+ return iof of a and b, numpy version for data augenmentation
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
+ """Match each prior box with the ground truth box of the highest jaccard
+ overlap, encode the bounding boxes, then return the matched indices
+ corresponding to both confidence and location preds.
+ Args:
+ threshold: (float) The overlap threshold used when mathing boxes.
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors].
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
+ variances: (tensor) Variances corresponding to each prior coord,
+ Shape: [num_priors, 4].
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
+ loc_t: (tensor) Tensor to be filled w/ endcoded location targets.
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+ idx: (int) current batch index
+ Return:
+ The matched indices corresponding to 1)location and 2)confidence preds.
+ """
+ # jaccard index
+ overlaps = jaccard(
+ truths,
+ point_form(priors)
+ )
+ # (Bipartite Matching)
+ # [1,num_objects] best prior for each ground truth
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
+
+ # ignore hard gt
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
+ if best_prior_idx_filter.shape[0] <= 0:
+ loc_t[idx] = 0
+ conf_t[idx] = 0
+ return
+
+ # [1,num_priors] best ground truth for each prior
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
+ best_truth_idx.squeeze_(0)
+ best_truth_overlap.squeeze_(0)
+ best_prior_idx.squeeze_(1)
+ best_prior_idx_filter.squeeze_(1)
+ best_prior_overlap.squeeze_(1)
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
+ # TODO refactor: index best_prior_idx with long tensor
+ # ensure every gt matches with its prior of max overlap
+ for j in range(best_prior_idx.size(0)):
+ best_truth_idx[best_prior_idx[j]] = j
+ matches = truths[best_truth_idx] # Shape: [num_priors,4]
+ conf = labels[best_truth_idx] # Shape: [num_priors]
+ conf[best_truth_overlap < threshold] = 0 # label as background
+ loc = encode(matches, priors, variances)
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
+ conf_t[idx] = conf # [num_priors] top class label for each prior
+
+
+def encode(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 4].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded boxes (tensor), Shape: [num_priors, 4]
+ """
+
+ # dist b/t match center and prior's center
+ g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, 2:])
+ # match wh / prior wh
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+ g_wh = torch.log(g_wh) / variances[1]
+ # return target for smooth_l1_loss
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
+
+
+# Adapted from https://github.com/Hakuyume/chainer-ssd
+def decode(loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ loc (tensor): location predictions for loc layers,
+ Shape: [num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+
+ boxes = torch.cat((
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+ boxes[:, :2] -= boxes[:, 2:] / 2
+ boxes[:, 2:] += boxes[:, :2]
+ return boxes
+
+
+def log_sum_exp(x):
+ """Utility function for computing log_sum_exp while determining
+ This will be used to determine unaveraged confidence loss across
+ all examples in a batch.
+ Args:
+ x (Variable(tensor)): conf_preds from conf layers
+ """
+ x_max = x.data.max()
+ return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
+
+
+# Original author: Francisco Massa:
+# https://github.com/fmassa/object-detection.torch
+# Ported to PyTorch by Max deGroot (02/01/2017)
+def nms(boxes, scores, overlap=0.5, top_k=200):
+ """Apply non-maximum suppression at test time to avoid detecting too many
+ overlapping bounding boxes for a given object.
+ Args:
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
+ scores: (tensor) The class predscores for the img, Shape:[num_priors].
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
+ top_k: (int) The Maximum number of box preds to consider.
+ Return:
+ The indices of the kept boxes with respect to num_priors.
+ """
+
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
+ if boxes.numel() == 0:
+ return keep
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ area = torch.mul(x2 - x1, y2 - y1)
+ v, idx = scores.sort(0) # sort in ascending order
+ # I = I[v >= 0.01]
+ idx = idx[-top_k:] # indices of the top-k largest vals
+ xx1 = boxes.new()
+ yy1 = boxes.new()
+ xx2 = boxes.new()
+ yy2 = boxes.new()
+ w = boxes.new()
+ h = boxes.new()
+
+ # keep = torch.Tensor()
+ count = 0
+ while idx.numel() > 0:
+ i = idx[-1] # index of current largest val
+ # keep.append(i)
+ keep[count] = i
+ count += 1
+ if idx.size(0) == 1:
+ break
+ idx = idx[:-1] # remove kept element from view
+ # load bboxes of next highest vals
+ torch.index_select(x1, 0, idx, out=xx1)
+ torch.index_select(y1, 0, idx, out=yy1)
+ torch.index_select(x2, 0, idx, out=xx2)
+ torch.index_select(y2, 0, idx, out=yy2)
+ # store element-wise max with next highest score
+ xx1 = torch.clamp(xx1, min=x1[i])
+ yy1 = torch.clamp(yy1, min=y1[i])
+ xx2 = torch.clamp(xx2, max=x2[i])
+ yy2 = torch.clamp(yy2, max=y2[i])
+ w.resize_as_(xx2)
+ h.resize_as_(yy2)
+ w = xx2 - xx1
+ h = yy2 - yy1
+ # check sizes of xx1 and xx2.. after each iteration
+ w = torch.clamp(w, min=0.0)
+ h = torch.clamp(h, min=0.0)
+ inter = w*h
+ # IoU = i / (area(a) + area(b) - i)
+ rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
+ union = (rem_areas - inter) + area[i]
+ IoU = inter/union # store result in iou
+ # keep only elements with an IoU <= overlap
+ idx = idx[IoU.le(overlap)]
+ return keep, count
+
+
diff --git a/external/landmark_detection/FaceBoxesV2/utils/build.py b/external/landmark_detection/FaceBoxesV2/utils/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1d4fb495db46eb5eb0b311d9bfbd5e111d49c56
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/build.py
@@ -0,0 +1,57 @@
+# coding: utf-8
+
+# --------------------------------------------------------
+# Fast R-CNN
+# Copyright (c) 2015 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+import os
+from os.path import join as pjoin
+import numpy as np
+from distutils.core import setup
+from distutils.extension import Extension
+from Cython.Distutils import build_ext
+
+
+def find_in_path(name, path):
+ "Find a file in a search path"
+ # adapted fom http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/
+ for dir in path.split(os.pathsep):
+ binpath = pjoin(dir, name)
+ if os.path.exists(binpath):
+ return os.path.abspath(binpath)
+ return None
+
+
+# Obtain the numpy include directory. This logic works across numpy versions.
+try:
+ numpy_include = np.get_include()
+except AttributeError:
+ numpy_include = np.get_numpy_include()
+
+
+# run the customize_compiler
+class custom_build_ext(build_ext):
+ def build_extensions(self):
+ # customize_compiler_for_nvcc(self.compiler)
+ build_ext.build_extensions(self)
+
+
+ext_modules = [
+ Extension(
+ "nms.cpu_nms",
+ ["nms/cpu_nms.pyx"],
+ # extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
+ extra_compile_args=["-Wno-cpp", "-Wno-unused-function"],
+ include_dirs=[numpy_include]
+ )
+]
+
+setup(
+ name='mot_utils',
+ ext_modules=ext_modules,
+ # inject our custom trigger
+ cmdclass={'build_ext': custom_build_ext},
+)
diff --git a/external/landmark_detection/FaceBoxesV2/utils/config.py b/external/landmark_detection/FaceBoxesV2/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..527c8b3754fbc6e1187e4539bf5075cd0cea4952
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/config.py
@@ -0,0 +1,14 @@
+# config.py
+
+cfg = {
+ 'name': 'FaceBoxes',
+ #'min_dim': 1024,
+ #'feature_maps': [[32, 32], [16, 16], [8, 8]],
+ # 'aspect_ratios': [[1], [1], [1]],
+ 'min_sizes': [[32, 64, 128], [256], [512]],
+ 'steps': [32, 64, 128],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True
+}
diff --git a/external/landmark_detection/FaceBoxesV2/utils/faceboxes.py b/external/landmark_detection/FaceBoxesV2/utils/faceboxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ae4d31a7dfde983da5459c169a30e5a11ccdd7a
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/faceboxes.py
@@ -0,0 +1,239 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicConv2d(nn.Module):
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(BasicConv2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+ self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return F.relu(x, inplace=True)
+
+
+class Inception(nn.Module):
+
+ def __init__(self):
+ super(Inception, self).__init__()
+ self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0)
+ self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0)
+ self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0)
+ self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1)
+ self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0)
+ self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1)
+ self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch1x1_2 = self.branch1x1_2(branch1x1_pool)
+
+ branch3x3_reduce = self.branch3x3_reduce(x)
+ branch3x3 = self.branch3x3(branch3x3_reduce)
+
+ branch3x3_reduce_2 = self.branch3x3_reduce_2(x)
+ branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2)
+ branch3x3_3 = self.branch3x3_3(branch3x3_2)
+
+ outputs = [branch1x1, branch1x1_2, branch3x3, branch3x3_3]
+ return torch.cat(outputs, 1)
+
+
+class CRelu(nn.Module):
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(CRelu, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+ self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = torch.cat([x, -x], 1)
+ x = F.relu(x, inplace=True)
+ return x
+
+
+class FaceBoxes(nn.Module):
+
+ def __init__(self, phase, size, num_classes):
+ super(FaceBoxes, self).__init__()
+ self.phase = phase
+ self.num_classes = num_classes
+ self.size = size
+
+ self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3)
+ self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2)
+
+ self.inception1 = Inception()
+ self.inception2 = Inception()
+ self.inception3 = Inception()
+
+ self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
+ self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
+
+ self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
+ self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
+
+ self.loc, self.conf = self.multibox(self.num_classes)
+
+ if self.phase == 'test':
+ self.softmax = nn.Softmax(dim=-1)
+
+ if self.phase == 'train':
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m.bias is not None:
+ nn.init.xavier_normal_(m.weight.data)
+ m.bias.data.fill_(0.02)
+ else:
+ m.weight.data.normal_(0, 0.01)
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def multibox(self, num_classes):
+ loc_layers = []
+ conf_layers = []
+ loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)]
+ conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)]
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
+ return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers)
+
+ def forward(self, x):
+
+ detection_sources = list()
+ loc = list()
+ conf = list()
+
+ x = self.conv1(x)
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
+ x = self.conv2(x)
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
+ x = self.inception1(x)
+ x = self.inception2(x)
+ x = self.inception3(x)
+ detection_sources.append(x)
+
+ x = self.conv3_1(x)
+ x = self.conv3_2(x)
+ detection_sources.append(x)
+
+ x = self.conv4_1(x)
+ x = self.conv4_2(x)
+ detection_sources.append(x)
+
+ for (x, l, c) in zip(detection_sources, self.loc, self.conf):
+ loc.append(l(x).permute(0, 2, 3, 1).contiguous())
+ conf.append(c(x).permute(0, 2, 3, 1).contiguous())
+
+ loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
+ conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
+
+ if self.phase == "test":
+ output = (loc.view(loc.size(0), -1, 4),
+ self.softmax(conf.view(conf.size(0), -1, self.num_classes)))
+ else:
+ output = (loc.view(loc.size(0), -1, 4),
+ conf.view(conf.size(0), -1, self.num_classes))
+
+ return output
+
+class FaceBoxesV2(nn.Module):
+
+ def __init__(self, phase, size, num_classes):
+ super(FaceBoxesV2, self).__init__()
+ self.phase = phase
+ self.num_classes = num_classes
+ self.size = size
+
+ self.conv1 = BasicConv2d(3, 8, kernel_size=3, stride=2, padding=1)
+ self.conv2 = BasicConv2d(8, 16, kernel_size=3, stride=2, padding=1)
+ self.conv3 = BasicConv2d(16, 32, kernel_size=3, stride=2, padding=1)
+ self.conv4 = BasicConv2d(32, 64, kernel_size=3, stride=2, padding=1)
+ self.conv5 = BasicConv2d(64, 128, kernel_size=3, stride=2, padding=1)
+
+ self.inception1 = Inception()
+ self.inception2 = Inception()
+ self.inception3 = Inception()
+
+ self.conv6_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
+ self.conv6_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
+
+ self.conv7_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
+ self.conv7_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
+
+ self.loc, self.conf = self.multibox(self.num_classes)
+
+ if self.phase == 'test':
+ self.softmax = nn.Softmax(dim=-1)
+
+ if self.phase == 'train':
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m.bias is not None:
+ nn.init.xavier_normal_(m.weight.data)
+ m.bias.data.fill_(0.02)
+ else:
+ m.weight.data.normal_(0, 0.01)
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def multibox(self, num_classes):
+ loc_layers = []
+ conf_layers = []
+ loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)]
+ conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)]
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
+ return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers)
+
+ def forward(self, x):
+
+ sources = list()
+ loc = list()
+ conf = list()
+
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+ x = self.conv5(x)
+ x = self.inception1(x)
+ x = self.inception2(x)
+ x = self.inception3(x)
+ sources.append(x)
+ x = self.conv6_1(x)
+ x = self.conv6_2(x)
+ sources.append(x)
+ x = self.conv7_1(x)
+ x = self.conv7_2(x)
+ sources.append(x)
+
+ for (x, l, c) in zip(sources, self.loc, self.conf):
+ loc.append(l(x).permute(0, 2, 3, 1).contiguous())
+ conf.append(c(x).permute(0, 2, 3, 1).contiguous())
+
+ loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
+ conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
+
+ if self.phase == "test":
+ output = (loc.view(loc.size(0), -1, 4),
+ self.softmax(conf.view(-1, self.num_classes)))
+ else:
+ output = (loc.view(loc.size(0), -1, 4),
+ conf.view(conf.size(0), -1, self.num_classes))
+
+ return output
diff --git a/external/landmark_detection/FaceBoxesV2/utils/make.sh b/external/landmark_detection/FaceBoxesV2/utils/make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9693ed462e516432e100cc743ae3a1417aa12c41
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/make.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env bash
+python3 build.py build_ext --inplace
+
diff --git a/external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py b/external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.c b/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.c
new file mode 100644
index 0000000000000000000000000000000000000000..a96bf32bf15040b52a08e296f41b7c547067d575
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.c
@@ -0,0 +1,14164 @@
+/* Generated by Cython 3.0.12 */
+
+/* BEGIN: Cython Metadata
+{
+ "distutils": {
+ "depends": [
+ "/home/yisheng/Data16T/conda_envs/gagavatar/lib/python3.10/site-packages/numpy/core/include/numpy/arrayobject.h",
+ "/home/yisheng/Data16T/conda_envs/gagavatar/lib/python3.10/site-packages/numpy/core/include/numpy/arrayscalars.h",
+ "/home/yisheng/Data16T/conda_envs/gagavatar/lib/python3.10/site-packages/numpy/core/include/numpy/ndarrayobject.h",
+ "/home/yisheng/Data16T/conda_envs/gagavatar/lib/python3.10/site-packages/numpy/core/include/numpy/ndarraytypes.h",
+ "/home/yisheng/Data16T/conda_envs/gagavatar/lib/python3.10/site-packages/numpy/core/include/numpy/ufuncobject.h"
+ ],
+ "extra_compile_args": [
+ "-Wno-cpp",
+ "-Wno-unused-function"
+ ],
+ "include_dirs": [
+ "/home/yisheng/Data16T/conda_envs/gagavatar/lib/python3.10/site-packages/numpy/core/include"
+ ],
+ "name": "nms.cpu_nms",
+ "sources": [
+ "nms/cpu_nms.pyx"
+ ]
+ },
+ "module_name": "nms.cpu_nms"
+}
+END: Cython Metadata */
+
+#ifndef PY_SSIZE_T_CLEAN
+#define PY_SSIZE_T_CLEAN
+#endif /* PY_SSIZE_T_CLEAN */
+#if defined(CYTHON_LIMITED_API) && 0
+ #ifndef Py_LIMITED_API
+ #if CYTHON_LIMITED_API+0 > 0x03030000
+ #define Py_LIMITED_API CYTHON_LIMITED_API
+ #else
+ #define Py_LIMITED_API 0x03030000
+ #endif
+ #endif
+#endif
+
+#include "Python.h"
+#ifndef Py_PYTHON_H
+ #error Python headers needed to compile C extensions, please install development version of Python.
+#elif PY_VERSION_HEX < 0x02070000 || (0x03000000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x03030000)
+ #error Cython requires Python 2.7+ or Python 3.3+.
+#else
+#if defined(CYTHON_LIMITED_API) && CYTHON_LIMITED_API
+#define __PYX_EXTRA_ABI_MODULE_NAME "limited"
+#else
+#define __PYX_EXTRA_ABI_MODULE_NAME ""
+#endif
+#define CYTHON_ABI "3_0_12" __PYX_EXTRA_ABI_MODULE_NAME
+#define __PYX_ABI_MODULE_NAME "_cython_" CYTHON_ABI
+#define __PYX_TYPE_MODULE_PREFIX __PYX_ABI_MODULE_NAME "."
+#define CYTHON_HEX_VERSION 0x03000CF0
+#define CYTHON_FUTURE_DIVISION 1
+#include
+#ifndef offsetof
+ #define offsetof(type, member) ( (size_t) & ((type*)0) -> member )
+#endif
+#if !defined(_WIN32) && !defined(WIN32) && !defined(MS_WINDOWS)
+ #ifndef __stdcall
+ #define __stdcall
+ #endif
+ #ifndef __cdecl
+ #define __cdecl
+ #endif
+ #ifndef __fastcall
+ #define __fastcall
+ #endif
+#endif
+#ifndef DL_IMPORT
+ #define DL_IMPORT(t) t
+#endif
+#ifndef DL_EXPORT
+ #define DL_EXPORT(t) t
+#endif
+#define __PYX_COMMA ,
+#ifndef HAVE_LONG_LONG
+ #define HAVE_LONG_LONG
+#endif
+#ifndef PY_LONG_LONG
+ #define PY_LONG_LONG LONG_LONG
+#endif
+#ifndef Py_HUGE_VAL
+ #define Py_HUGE_VAL HUGE_VAL
+#endif
+#define __PYX_LIMITED_VERSION_HEX PY_VERSION_HEX
+#if defined(GRAALVM_PYTHON)
+ /* For very preliminary testing purposes. Most variables are set the same as PyPy.
+ The existence of this section does not imply that anything works or is even tested */
+ #define CYTHON_COMPILING_IN_PYPY 0
+ #define CYTHON_COMPILING_IN_CPYTHON 0
+ #define CYTHON_COMPILING_IN_LIMITED_API 0
+ #define CYTHON_COMPILING_IN_GRAAL 1
+ #define CYTHON_COMPILING_IN_NOGIL 0
+ #undef CYTHON_USE_TYPE_SLOTS
+ #define CYTHON_USE_TYPE_SLOTS 0
+ #undef CYTHON_USE_TYPE_SPECS
+ #define CYTHON_USE_TYPE_SPECS 0
+ #undef CYTHON_USE_PYTYPE_LOOKUP
+ #define CYTHON_USE_PYTYPE_LOOKUP 0
+ #if PY_VERSION_HEX < 0x03050000
+ #undef CYTHON_USE_ASYNC_SLOTS
+ #define CYTHON_USE_ASYNC_SLOTS 0
+ #elif !defined(CYTHON_USE_ASYNC_SLOTS)
+ #define CYTHON_USE_ASYNC_SLOTS 1
+ #endif
+ #undef CYTHON_USE_PYLIST_INTERNALS
+ #define CYTHON_USE_PYLIST_INTERNALS 0
+ #undef CYTHON_USE_UNICODE_INTERNALS
+ #define CYTHON_USE_UNICODE_INTERNALS 0
+ #undef CYTHON_USE_UNICODE_WRITER
+ #define CYTHON_USE_UNICODE_WRITER 0
+ #undef CYTHON_USE_PYLONG_INTERNALS
+ #define CYTHON_USE_PYLONG_INTERNALS 0
+ #undef CYTHON_AVOID_BORROWED_REFS
+ #define CYTHON_AVOID_BORROWED_REFS 1
+ #undef CYTHON_ASSUME_SAFE_MACROS
+ #define CYTHON_ASSUME_SAFE_MACROS 0
+ #undef CYTHON_UNPACK_METHODS
+ #define CYTHON_UNPACK_METHODS 0
+ #undef CYTHON_FAST_THREAD_STATE
+ #define CYTHON_FAST_THREAD_STATE 0
+ #undef CYTHON_FAST_GIL
+ #define CYTHON_FAST_GIL 0
+ #undef CYTHON_METH_FASTCALL
+ #define CYTHON_METH_FASTCALL 0
+ #undef CYTHON_FAST_PYCALL
+ #define CYTHON_FAST_PYCALL 0
+ #ifndef CYTHON_PEP487_INIT_SUBCLASS
+ #define CYTHON_PEP487_INIT_SUBCLASS (PY_MAJOR_VERSION >= 3)
+ #endif
+ #undef CYTHON_PEP489_MULTI_PHASE_INIT
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 1
+ #undef CYTHON_USE_MODULE_STATE
+ #define CYTHON_USE_MODULE_STATE 0
+ #undef CYTHON_USE_TP_FINALIZE
+ #define CYTHON_USE_TP_FINALIZE 0
+ #undef CYTHON_USE_DICT_VERSIONS
+ #define CYTHON_USE_DICT_VERSIONS 0
+ #undef CYTHON_USE_EXC_INFO_STACK
+ #define CYTHON_USE_EXC_INFO_STACK 0
+ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC
+ #define CYTHON_UPDATE_DESCRIPTOR_DOC 0
+ #endif
+ #undef CYTHON_USE_FREELISTS
+ #define CYTHON_USE_FREELISTS 0
+#elif defined(PYPY_VERSION)
+ #define CYTHON_COMPILING_IN_PYPY 1
+ #define CYTHON_COMPILING_IN_CPYTHON 0
+ #define CYTHON_COMPILING_IN_LIMITED_API 0
+ #define CYTHON_COMPILING_IN_GRAAL 0
+ #define CYTHON_COMPILING_IN_NOGIL 0
+ #undef CYTHON_USE_TYPE_SLOTS
+ #define CYTHON_USE_TYPE_SLOTS 0
+ #ifndef CYTHON_USE_TYPE_SPECS
+ #define CYTHON_USE_TYPE_SPECS 0
+ #endif
+ #undef CYTHON_USE_PYTYPE_LOOKUP
+ #define CYTHON_USE_PYTYPE_LOOKUP 0
+ #if PY_VERSION_HEX < 0x03050000
+ #undef CYTHON_USE_ASYNC_SLOTS
+ #define CYTHON_USE_ASYNC_SLOTS 0
+ #elif !defined(CYTHON_USE_ASYNC_SLOTS)
+ #define CYTHON_USE_ASYNC_SLOTS 1
+ #endif
+ #undef CYTHON_USE_PYLIST_INTERNALS
+ #define CYTHON_USE_PYLIST_INTERNALS 0
+ #undef CYTHON_USE_UNICODE_INTERNALS
+ #define CYTHON_USE_UNICODE_INTERNALS 0
+ #undef CYTHON_USE_UNICODE_WRITER
+ #define CYTHON_USE_UNICODE_WRITER 0
+ #undef CYTHON_USE_PYLONG_INTERNALS
+ #define CYTHON_USE_PYLONG_INTERNALS 0
+ #undef CYTHON_AVOID_BORROWED_REFS
+ #define CYTHON_AVOID_BORROWED_REFS 1
+ #undef CYTHON_ASSUME_SAFE_MACROS
+ #define CYTHON_ASSUME_SAFE_MACROS 0
+ #undef CYTHON_UNPACK_METHODS
+ #define CYTHON_UNPACK_METHODS 0
+ #undef CYTHON_FAST_THREAD_STATE
+ #define CYTHON_FAST_THREAD_STATE 0
+ #undef CYTHON_FAST_GIL
+ #define CYTHON_FAST_GIL 0
+ #undef CYTHON_METH_FASTCALL
+ #define CYTHON_METH_FASTCALL 0
+ #undef CYTHON_FAST_PYCALL
+ #define CYTHON_FAST_PYCALL 0
+ #ifndef CYTHON_PEP487_INIT_SUBCLASS
+ #define CYTHON_PEP487_INIT_SUBCLASS (PY_MAJOR_VERSION >= 3)
+ #endif
+ #if PY_VERSION_HEX < 0x03090000
+ #undef CYTHON_PEP489_MULTI_PHASE_INIT
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 0
+ #elif !defined(CYTHON_PEP489_MULTI_PHASE_INIT)
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 1
+ #endif
+ #undef CYTHON_USE_MODULE_STATE
+ #define CYTHON_USE_MODULE_STATE 0
+ #undef CYTHON_USE_TP_FINALIZE
+ #define CYTHON_USE_TP_FINALIZE (PY_VERSION_HEX >= 0x030400a1 && PYPY_VERSION_NUM >= 0x07030C00)
+ #undef CYTHON_USE_DICT_VERSIONS
+ #define CYTHON_USE_DICT_VERSIONS 0
+ #undef CYTHON_USE_EXC_INFO_STACK
+ #define CYTHON_USE_EXC_INFO_STACK 0
+ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC
+ #define CYTHON_UPDATE_DESCRIPTOR_DOC 0
+ #endif
+ #undef CYTHON_USE_FREELISTS
+ #define CYTHON_USE_FREELISTS 0
+#elif defined(CYTHON_LIMITED_API)
+ #ifdef Py_LIMITED_API
+ #undef __PYX_LIMITED_VERSION_HEX
+ #define __PYX_LIMITED_VERSION_HEX Py_LIMITED_API
+ #endif
+ #define CYTHON_COMPILING_IN_PYPY 0
+ #define CYTHON_COMPILING_IN_CPYTHON 0
+ #define CYTHON_COMPILING_IN_LIMITED_API 1
+ #define CYTHON_COMPILING_IN_GRAAL 0
+ #define CYTHON_COMPILING_IN_NOGIL 0
+ #undef CYTHON_CLINE_IN_TRACEBACK
+ #define CYTHON_CLINE_IN_TRACEBACK 0
+ #undef CYTHON_USE_TYPE_SLOTS
+ #define CYTHON_USE_TYPE_SLOTS 0
+ #undef CYTHON_USE_TYPE_SPECS
+ #define CYTHON_USE_TYPE_SPECS 1
+ #undef CYTHON_USE_PYTYPE_LOOKUP
+ #define CYTHON_USE_PYTYPE_LOOKUP 0
+ #undef CYTHON_USE_ASYNC_SLOTS
+ #define CYTHON_USE_ASYNC_SLOTS 0
+ #undef CYTHON_USE_PYLIST_INTERNALS
+ #define CYTHON_USE_PYLIST_INTERNALS 0
+ #undef CYTHON_USE_UNICODE_INTERNALS
+ #define CYTHON_USE_UNICODE_INTERNALS 0
+ #ifndef CYTHON_USE_UNICODE_WRITER
+ #define CYTHON_USE_UNICODE_WRITER 0
+ #endif
+ #undef CYTHON_USE_PYLONG_INTERNALS
+ #define CYTHON_USE_PYLONG_INTERNALS 0
+ #ifndef CYTHON_AVOID_BORROWED_REFS
+ #define CYTHON_AVOID_BORROWED_REFS 0
+ #endif
+ #undef CYTHON_ASSUME_SAFE_MACROS
+ #define CYTHON_ASSUME_SAFE_MACROS 0
+ #undef CYTHON_UNPACK_METHODS
+ #define CYTHON_UNPACK_METHODS 0
+ #undef CYTHON_FAST_THREAD_STATE
+ #define CYTHON_FAST_THREAD_STATE 0
+ #undef CYTHON_FAST_GIL
+ #define CYTHON_FAST_GIL 0
+ #undef CYTHON_METH_FASTCALL
+ #define CYTHON_METH_FASTCALL 0
+ #undef CYTHON_FAST_PYCALL
+ #define CYTHON_FAST_PYCALL 0
+ #ifndef CYTHON_PEP487_INIT_SUBCLASS
+ #define CYTHON_PEP487_INIT_SUBCLASS 1
+ #endif
+ #undef CYTHON_PEP489_MULTI_PHASE_INIT
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 0
+ #undef CYTHON_USE_MODULE_STATE
+ #define CYTHON_USE_MODULE_STATE 1
+ #ifndef CYTHON_USE_TP_FINALIZE
+ #define CYTHON_USE_TP_FINALIZE 0
+ #endif
+ #undef CYTHON_USE_DICT_VERSIONS
+ #define CYTHON_USE_DICT_VERSIONS 0
+ #undef CYTHON_USE_EXC_INFO_STACK
+ #define CYTHON_USE_EXC_INFO_STACK 0
+ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC
+ #define CYTHON_UPDATE_DESCRIPTOR_DOC 0
+ #endif
+ #undef CYTHON_USE_FREELISTS
+ #define CYTHON_USE_FREELISTS 0
+#elif defined(Py_GIL_DISABLED) || defined(Py_NOGIL)
+ #define CYTHON_COMPILING_IN_PYPY 0
+ #define CYTHON_COMPILING_IN_CPYTHON 0
+ #define CYTHON_COMPILING_IN_LIMITED_API 0
+ #define CYTHON_COMPILING_IN_GRAAL 0
+ #define CYTHON_COMPILING_IN_NOGIL 1
+ #ifndef CYTHON_USE_TYPE_SLOTS
+ #define CYTHON_USE_TYPE_SLOTS 1
+ #endif
+ #ifndef CYTHON_USE_TYPE_SPECS
+ #define CYTHON_USE_TYPE_SPECS 0
+ #endif
+ #undef CYTHON_USE_PYTYPE_LOOKUP
+ #define CYTHON_USE_PYTYPE_LOOKUP 0
+ #ifndef CYTHON_USE_ASYNC_SLOTS
+ #define CYTHON_USE_ASYNC_SLOTS 1
+ #endif
+ #ifndef CYTHON_USE_PYLONG_INTERNALS
+ #define CYTHON_USE_PYLONG_INTERNALS 0
+ #endif
+ #undef CYTHON_USE_PYLIST_INTERNALS
+ #define CYTHON_USE_PYLIST_INTERNALS 0
+ #ifndef CYTHON_USE_UNICODE_INTERNALS
+ #define CYTHON_USE_UNICODE_INTERNALS 1
+ #endif
+ #undef CYTHON_USE_UNICODE_WRITER
+ #define CYTHON_USE_UNICODE_WRITER 0
+ #ifndef CYTHON_AVOID_BORROWED_REFS
+ #define CYTHON_AVOID_BORROWED_REFS 0
+ #endif
+ #ifndef CYTHON_ASSUME_SAFE_MACROS
+ #define CYTHON_ASSUME_SAFE_MACROS 1
+ #endif
+ #ifndef CYTHON_UNPACK_METHODS
+ #define CYTHON_UNPACK_METHODS 1
+ #endif
+ #undef CYTHON_FAST_THREAD_STATE
+ #define CYTHON_FAST_THREAD_STATE 0
+ #undef CYTHON_FAST_GIL
+ #define CYTHON_FAST_GIL 0
+ #ifndef CYTHON_METH_FASTCALL
+ #define CYTHON_METH_FASTCALL 1
+ #endif
+ #undef CYTHON_FAST_PYCALL
+ #define CYTHON_FAST_PYCALL 0
+ #ifndef CYTHON_PEP487_INIT_SUBCLASS
+ #define CYTHON_PEP487_INIT_SUBCLASS 1
+ #endif
+ #ifndef CYTHON_PEP489_MULTI_PHASE_INIT
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 1
+ #endif
+ #ifndef CYTHON_USE_MODULE_STATE
+ #define CYTHON_USE_MODULE_STATE 0
+ #endif
+ #ifndef CYTHON_USE_TP_FINALIZE
+ #define CYTHON_USE_TP_FINALIZE 1
+ #endif
+ #undef CYTHON_USE_DICT_VERSIONS
+ #define CYTHON_USE_DICT_VERSIONS 0
+ #undef CYTHON_USE_EXC_INFO_STACK
+ #define CYTHON_USE_EXC_INFO_STACK 0
+ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC
+ #define CYTHON_UPDATE_DESCRIPTOR_DOC 1
+ #endif
+ #ifndef CYTHON_USE_FREELISTS
+ #define CYTHON_USE_FREELISTS 0
+ #endif
+#else
+ #define CYTHON_COMPILING_IN_PYPY 0
+ #define CYTHON_COMPILING_IN_CPYTHON 1
+ #define CYTHON_COMPILING_IN_LIMITED_API 0
+ #define CYTHON_COMPILING_IN_GRAAL 0
+ #define CYTHON_COMPILING_IN_NOGIL 0
+ #ifndef CYTHON_USE_TYPE_SLOTS
+ #define CYTHON_USE_TYPE_SLOTS 1
+ #endif
+ #ifndef CYTHON_USE_TYPE_SPECS
+ #define CYTHON_USE_TYPE_SPECS 0
+ #endif
+ #ifndef CYTHON_USE_PYTYPE_LOOKUP
+ #define CYTHON_USE_PYTYPE_LOOKUP 1
+ #endif
+ #if PY_MAJOR_VERSION < 3
+ #undef CYTHON_USE_ASYNC_SLOTS
+ #define CYTHON_USE_ASYNC_SLOTS 0
+ #elif !defined(CYTHON_USE_ASYNC_SLOTS)
+ #define CYTHON_USE_ASYNC_SLOTS 1
+ #endif
+ #ifndef CYTHON_USE_PYLONG_INTERNALS
+ #define CYTHON_USE_PYLONG_INTERNALS 1
+ #endif
+ #ifndef CYTHON_USE_PYLIST_INTERNALS
+ #define CYTHON_USE_PYLIST_INTERNALS 1
+ #endif
+ #ifndef CYTHON_USE_UNICODE_INTERNALS
+ #define CYTHON_USE_UNICODE_INTERNALS 1
+ #endif
+ #if PY_VERSION_HEX < 0x030300F0 || PY_VERSION_HEX >= 0x030B00A2
+ #undef CYTHON_USE_UNICODE_WRITER
+ #define CYTHON_USE_UNICODE_WRITER 0
+ #elif !defined(CYTHON_USE_UNICODE_WRITER)
+ #define CYTHON_USE_UNICODE_WRITER 1
+ #endif
+ #ifndef CYTHON_AVOID_BORROWED_REFS
+ #define CYTHON_AVOID_BORROWED_REFS 0
+ #endif
+ #ifndef CYTHON_ASSUME_SAFE_MACROS
+ #define CYTHON_ASSUME_SAFE_MACROS 1
+ #endif
+ #ifndef CYTHON_UNPACK_METHODS
+ #define CYTHON_UNPACK_METHODS 1
+ #endif
+ #ifndef CYTHON_FAST_THREAD_STATE
+ #define CYTHON_FAST_THREAD_STATE 1
+ #endif
+ #ifndef CYTHON_FAST_GIL
+ #define CYTHON_FAST_GIL (PY_MAJOR_VERSION < 3 || PY_VERSION_HEX >= 0x03060000 && PY_VERSION_HEX < 0x030C00A6)
+ #endif
+ #ifndef CYTHON_METH_FASTCALL
+ #define CYTHON_METH_FASTCALL (PY_VERSION_HEX >= 0x030700A1)
+ #endif
+ #ifndef CYTHON_FAST_PYCALL
+ #define CYTHON_FAST_PYCALL 1
+ #endif
+ #ifndef CYTHON_PEP487_INIT_SUBCLASS
+ #define CYTHON_PEP487_INIT_SUBCLASS 1
+ #endif
+ #if PY_VERSION_HEX < 0x03050000
+ #undef CYTHON_PEP489_MULTI_PHASE_INIT
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 0
+ #elif !defined(CYTHON_PEP489_MULTI_PHASE_INIT)
+ #define CYTHON_PEP489_MULTI_PHASE_INIT 1
+ #endif
+ #ifndef CYTHON_USE_MODULE_STATE
+ #define CYTHON_USE_MODULE_STATE 0
+ #endif
+ #if PY_VERSION_HEX < 0x030400a1
+ #undef CYTHON_USE_TP_FINALIZE
+ #define CYTHON_USE_TP_FINALIZE 0
+ #elif !defined(CYTHON_USE_TP_FINALIZE)
+ #define CYTHON_USE_TP_FINALIZE 1
+ #endif
+ #if PY_VERSION_HEX < 0x030600B1
+ #undef CYTHON_USE_DICT_VERSIONS
+ #define CYTHON_USE_DICT_VERSIONS 0
+ #elif !defined(CYTHON_USE_DICT_VERSIONS)
+ #define CYTHON_USE_DICT_VERSIONS (PY_VERSION_HEX < 0x030C00A5)
+ #endif
+ #if PY_VERSION_HEX < 0x030700A3
+ #undef CYTHON_USE_EXC_INFO_STACK
+ #define CYTHON_USE_EXC_INFO_STACK 0
+ #elif !defined(CYTHON_USE_EXC_INFO_STACK)
+ #define CYTHON_USE_EXC_INFO_STACK 1
+ #endif
+ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC
+ #define CYTHON_UPDATE_DESCRIPTOR_DOC 1
+ #endif
+ #ifndef CYTHON_USE_FREELISTS
+ #define CYTHON_USE_FREELISTS 1
+ #endif
+#endif
+#if !defined(CYTHON_FAST_PYCCALL)
+#define CYTHON_FAST_PYCCALL (CYTHON_FAST_PYCALL && PY_VERSION_HEX >= 0x030600B1)
+#endif
+#if !defined(CYTHON_VECTORCALL)
+#define CYTHON_VECTORCALL (CYTHON_FAST_PYCCALL && PY_VERSION_HEX >= 0x030800B1)
+#endif
+#define CYTHON_BACKPORT_VECTORCALL (CYTHON_METH_FASTCALL && PY_VERSION_HEX < 0x030800B1)
+#if CYTHON_USE_PYLONG_INTERNALS
+ #if PY_MAJOR_VERSION < 3
+ #include "longintrepr.h"
+ #endif
+ #undef SHIFT
+ #undef BASE
+ #undef MASK
+ #ifdef SIZEOF_VOID_P
+ enum { __pyx_check_sizeof_voidp = 1 / (int)(SIZEOF_VOID_P == sizeof(void*)) };
+ #endif
+#endif
+#ifndef __has_attribute
+ #define __has_attribute(x) 0
+#endif
+#ifndef __has_cpp_attribute
+ #define __has_cpp_attribute(x) 0
+#endif
+#ifndef CYTHON_RESTRICT
+ #if defined(__GNUC__)
+ #define CYTHON_RESTRICT __restrict__
+ #elif defined(_MSC_VER) && _MSC_VER >= 1400
+ #define CYTHON_RESTRICT __restrict
+ #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
+ #define CYTHON_RESTRICT restrict
+ #else
+ #define CYTHON_RESTRICT
+ #endif
+#endif
+#ifndef CYTHON_UNUSED
+ #if defined(__cplusplus)
+ /* for clang __has_cpp_attribute(maybe_unused) is true even before C++17
+ * but leads to warnings with -pedantic, since it is a C++17 feature */
+ #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L)
+ #if __has_cpp_attribute(maybe_unused)
+ #define CYTHON_UNUSED [[maybe_unused]]
+ #endif
+ #endif
+ #endif
+#endif
+#ifndef CYTHON_UNUSED
+# if defined(__GNUC__)
+# if !(defined(__cplusplus)) || (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 4))
+# define CYTHON_UNUSED __attribute__ ((__unused__))
+# else
+# define CYTHON_UNUSED
+# endif
+# elif defined(__ICC) || (defined(__INTEL_COMPILER) && !defined(_MSC_VER))
+# define CYTHON_UNUSED __attribute__ ((__unused__))
+# else
+# define CYTHON_UNUSED
+# endif
+#endif
+#ifndef CYTHON_UNUSED_VAR
+# if defined(__cplusplus)
+ template void CYTHON_UNUSED_VAR( const T& ) { }
+# else
+# define CYTHON_UNUSED_VAR(x) (void)(x)
+# endif
+#endif
+#ifndef CYTHON_MAYBE_UNUSED_VAR
+ #define CYTHON_MAYBE_UNUSED_VAR(x) CYTHON_UNUSED_VAR(x)
+#endif
+#ifndef CYTHON_NCP_UNUSED
+# if CYTHON_COMPILING_IN_CPYTHON
+# define CYTHON_NCP_UNUSED
+# else
+# define CYTHON_NCP_UNUSED CYTHON_UNUSED
+# endif
+#endif
+#ifndef CYTHON_USE_CPP_STD_MOVE
+ #if defined(__cplusplus) && (\
+ __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1600))
+ #define CYTHON_USE_CPP_STD_MOVE 1
+ #else
+ #define CYTHON_USE_CPP_STD_MOVE 0
+ #endif
+#endif
+#define __Pyx_void_to_None(void_result) ((void)(void_result), Py_INCREF(Py_None), Py_None)
+#ifdef _MSC_VER
+ #ifndef _MSC_STDINT_H_
+ #if _MSC_VER < 1300
+ typedef unsigned char uint8_t;
+ typedef unsigned short uint16_t;
+ typedef unsigned int uint32_t;
+ #else
+ typedef unsigned __int8 uint8_t;
+ typedef unsigned __int16 uint16_t;
+ typedef unsigned __int32 uint32_t;
+ #endif
+ #endif
+ #if _MSC_VER < 1300
+ #ifdef _WIN64
+ typedef unsigned long long __pyx_uintptr_t;
+ #else
+ typedef unsigned int __pyx_uintptr_t;
+ #endif
+ #else
+ #ifdef _WIN64
+ typedef unsigned __int64 __pyx_uintptr_t;
+ #else
+ typedef unsigned __int32 __pyx_uintptr_t;
+ #endif
+ #endif
+#else
+ #include
+ typedef uintptr_t __pyx_uintptr_t;
+#endif
+#ifndef CYTHON_FALLTHROUGH
+ #if defined(__cplusplus)
+ /* for clang __has_cpp_attribute(fallthrough) is true even before C++17
+ * but leads to warnings with -pedantic, since it is a C++17 feature */
+ #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L)
+ #if __has_cpp_attribute(fallthrough)
+ #define CYTHON_FALLTHROUGH [[fallthrough]]
+ #endif
+ #endif
+ #ifndef CYTHON_FALLTHROUGH
+ #if __has_cpp_attribute(clang::fallthrough)
+ #define CYTHON_FALLTHROUGH [[clang::fallthrough]]
+ #elif __has_cpp_attribute(gnu::fallthrough)
+ #define CYTHON_FALLTHROUGH [[gnu::fallthrough]]
+ #endif
+ #endif
+ #endif
+ #ifndef CYTHON_FALLTHROUGH
+ #if __has_attribute(fallthrough)
+ #define CYTHON_FALLTHROUGH __attribute__((fallthrough))
+ #else
+ #define CYTHON_FALLTHROUGH
+ #endif
+ #endif
+ #if defined(__clang__) && defined(__apple_build_version__)
+ #if __apple_build_version__ < 7000000
+ #undef CYTHON_FALLTHROUGH
+ #define CYTHON_FALLTHROUGH
+ #endif
+ #endif
+#endif
+#ifdef __cplusplus
+ template
+ struct __PYX_IS_UNSIGNED_IMPL {static const bool value = T(0) < T(-1);};
+ #define __PYX_IS_UNSIGNED(type) (__PYX_IS_UNSIGNED_IMPL::value)
+#else
+ #define __PYX_IS_UNSIGNED(type) (((type)-1) > 0)
+#endif
+#if CYTHON_COMPILING_IN_PYPY == 1
+ #define __PYX_NEED_TP_PRINT_SLOT (PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x030A0000)
+#else
+ #define __PYX_NEED_TP_PRINT_SLOT (PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000)
+#endif
+#define __PYX_REINTERPRET_FUNCION(func_pointer, other_pointer) ((func_pointer)(void(*)(void))(other_pointer))
+
+#ifndef CYTHON_INLINE
+ #if defined(__clang__)
+ #define CYTHON_INLINE __inline__ __attribute__ ((__unused__))
+ #elif defined(__GNUC__)
+ #define CYTHON_INLINE __inline__
+ #elif defined(_MSC_VER)
+ #define CYTHON_INLINE __inline
+ #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
+ #define CYTHON_INLINE inline
+ #else
+ #define CYTHON_INLINE
+ #endif
+#endif
+
+#define __PYX_BUILD_PY_SSIZE_T "n"
+#define CYTHON_FORMAT_SSIZE_T "z"
+#if PY_MAJOR_VERSION < 3
+ #define __Pyx_BUILTIN_MODULE_NAME "__builtin__"
+ #define __Pyx_DefaultClassType PyClass_Type
+ #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\
+ PyCode_New(a+k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)
+#else
+ #define __Pyx_BUILTIN_MODULE_NAME "builtins"
+ #define __Pyx_DefaultClassType PyType_Type
+#if CYTHON_COMPILING_IN_LIMITED_API
+ static CYTHON_INLINE PyObject* __Pyx_PyCode_New(int a, int p, int k, int l, int s, int f,
+ PyObject *code, PyObject *c, PyObject* n, PyObject *v,
+ PyObject *fv, PyObject *cell, PyObject* fn,
+ PyObject *name, int fline, PyObject *lnos) {
+ PyObject *exception_table = NULL;
+ PyObject *types_module=NULL, *code_type=NULL, *result=NULL;
+ #if __PYX_LIMITED_VERSION_HEX < 0x030B0000
+ PyObject *version_info;
+ PyObject *py_minor_version = NULL;
+ #endif
+ long minor_version = 0;
+ PyObject *type, *value, *traceback;
+ PyErr_Fetch(&type, &value, &traceback);
+ #if __PYX_LIMITED_VERSION_HEX >= 0x030B0000
+ minor_version = 11;
+ #else
+ if (!(version_info = PySys_GetObject("version_info"))) goto end;
+ if (!(py_minor_version = PySequence_GetItem(version_info, 1))) goto end;
+ minor_version = PyLong_AsLong(py_minor_version);
+ Py_DECREF(py_minor_version);
+ if (minor_version == -1 && PyErr_Occurred()) goto end;
+ #endif
+ if (!(types_module = PyImport_ImportModule("types"))) goto end;
+ if (!(code_type = PyObject_GetAttrString(types_module, "CodeType"))) goto end;
+ if (minor_version <= 7) {
+ (void)p;
+ result = PyObject_CallFunction(code_type, "iiiiiOOOOOOiOO", a, k, l, s, f, code,
+ c, n, v, fn, name, fline, lnos, fv, cell);
+ } else if (minor_version <= 10) {
+ result = PyObject_CallFunction(code_type, "iiiiiiOOOOOOiOO", a,p, k, l, s, f, code,
+ c, n, v, fn, name, fline, lnos, fv, cell);
+ } else {
+ if (!(exception_table = PyBytes_FromStringAndSize(NULL, 0))) goto end;
+ result = PyObject_CallFunction(code_type, "iiiiiiOOOOOOOiOO", a,p, k, l, s, f, code,
+ c, n, v, fn, name, name, fline, lnos, exception_table, fv, cell);
+ }
+ end:
+ Py_XDECREF(code_type);
+ Py_XDECREF(exception_table);
+ Py_XDECREF(types_module);
+ if (type) {
+ PyErr_Restore(type, value, traceback);
+ }
+ return result;
+ }
+ #ifndef CO_OPTIMIZED
+ #define CO_OPTIMIZED 0x0001
+ #endif
+ #ifndef CO_NEWLOCALS
+ #define CO_NEWLOCALS 0x0002
+ #endif
+ #ifndef CO_VARARGS
+ #define CO_VARARGS 0x0004
+ #endif
+ #ifndef CO_VARKEYWORDS
+ #define CO_VARKEYWORDS 0x0008
+ #endif
+ #ifndef CO_ASYNC_GENERATOR
+ #define CO_ASYNC_GENERATOR 0x0200
+ #endif
+ #ifndef CO_GENERATOR
+ #define CO_GENERATOR 0x0020
+ #endif
+ #ifndef CO_COROUTINE
+ #define CO_COROUTINE 0x0080
+ #endif
+#elif PY_VERSION_HEX >= 0x030B0000
+ static CYTHON_INLINE PyCodeObject* __Pyx_PyCode_New(int a, int p, int k, int l, int s, int f,
+ PyObject *code, PyObject *c, PyObject* n, PyObject *v,
+ PyObject *fv, PyObject *cell, PyObject* fn,
+ PyObject *name, int fline, PyObject *lnos) {
+ PyCodeObject *result;
+ PyObject *empty_bytes = PyBytes_FromStringAndSize("", 0);
+ if (!empty_bytes) return NULL;
+ result =
+ #if PY_VERSION_HEX >= 0x030C0000
+ PyUnstable_Code_NewWithPosOnlyArgs
+ #else
+ PyCode_NewWithPosOnlyArgs
+ #endif
+ (a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, name, fline, lnos, empty_bytes);
+ Py_DECREF(empty_bytes);
+ return result;
+ }
+#elif PY_VERSION_HEX >= 0x030800B2 && !CYTHON_COMPILING_IN_PYPY
+ #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\
+ PyCode_NewWithPosOnlyArgs(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)
+#else
+ #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\
+ PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)
+#endif
+#endif
+#if PY_VERSION_HEX >= 0x030900A4 || defined(Py_IS_TYPE)
+ #define __Pyx_IS_TYPE(ob, type) Py_IS_TYPE(ob, type)
+#else
+ #define __Pyx_IS_TYPE(ob, type) (((const PyObject*)ob)->ob_type == (type))
+#endif
+#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_Is)
+ #define __Pyx_Py_Is(x, y) Py_Is(x, y)
+#else
+ #define __Pyx_Py_Is(x, y) ((x) == (y))
+#endif
+#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsNone)
+ #define __Pyx_Py_IsNone(ob) Py_IsNone(ob)
+#else
+ #define __Pyx_Py_IsNone(ob) __Pyx_Py_Is((ob), Py_None)
+#endif
+#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsTrue)
+ #define __Pyx_Py_IsTrue(ob) Py_IsTrue(ob)
+#else
+ #define __Pyx_Py_IsTrue(ob) __Pyx_Py_Is((ob), Py_True)
+#endif
+#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsFalse)
+ #define __Pyx_Py_IsFalse(ob) Py_IsFalse(ob)
+#else
+ #define __Pyx_Py_IsFalse(ob) __Pyx_Py_Is((ob), Py_False)
+#endif
+#define __Pyx_NoneAsNull(obj) (__Pyx_Py_IsNone(obj) ? NULL : (obj))
+#if PY_VERSION_HEX >= 0x030900F0 && !CYTHON_COMPILING_IN_PYPY
+ #define __Pyx_PyObject_GC_IsFinalized(o) PyObject_GC_IsFinalized(o)
+#else
+ #define __Pyx_PyObject_GC_IsFinalized(o) _PyGC_FINALIZED(o)
+#endif
+#ifndef CO_COROUTINE
+ #define CO_COROUTINE 0x80
+#endif
+#ifndef CO_ASYNC_GENERATOR
+ #define CO_ASYNC_GENERATOR 0x200
+#endif
+#ifndef Py_TPFLAGS_CHECKTYPES
+ #define Py_TPFLAGS_CHECKTYPES 0
+#endif
+#ifndef Py_TPFLAGS_HAVE_INDEX
+ #define Py_TPFLAGS_HAVE_INDEX 0
+#endif
+#ifndef Py_TPFLAGS_HAVE_NEWBUFFER
+ #define Py_TPFLAGS_HAVE_NEWBUFFER 0
+#endif
+#ifndef Py_TPFLAGS_HAVE_FINALIZE
+ #define Py_TPFLAGS_HAVE_FINALIZE 0
+#endif
+#ifndef Py_TPFLAGS_SEQUENCE
+ #define Py_TPFLAGS_SEQUENCE 0
+#endif
+#ifndef Py_TPFLAGS_MAPPING
+ #define Py_TPFLAGS_MAPPING 0
+#endif
+#ifndef METH_STACKLESS
+ #define METH_STACKLESS 0
+#endif
+#if PY_VERSION_HEX <= 0x030700A3 || !defined(METH_FASTCALL)
+ #ifndef METH_FASTCALL
+ #define METH_FASTCALL 0x80
+ #endif
+ typedef PyObject *(*__Pyx_PyCFunctionFast) (PyObject *self, PyObject *const *args, Py_ssize_t nargs);
+ typedef PyObject *(*__Pyx_PyCFunctionFastWithKeywords) (PyObject *self, PyObject *const *args,
+ Py_ssize_t nargs, PyObject *kwnames);
+#else
+ #if PY_VERSION_HEX >= 0x030d00A4
+ # define __Pyx_PyCFunctionFast PyCFunctionFast
+ # define __Pyx_PyCFunctionFastWithKeywords PyCFunctionFastWithKeywords
+ #else
+ # define __Pyx_PyCFunctionFast _PyCFunctionFast
+ # define __Pyx_PyCFunctionFastWithKeywords _PyCFunctionFastWithKeywords
+ #endif
+#endif
+#if CYTHON_METH_FASTCALL
+ #define __Pyx_METH_FASTCALL METH_FASTCALL
+ #define __Pyx_PyCFunction_FastCall __Pyx_PyCFunctionFast
+ #define __Pyx_PyCFunction_FastCallWithKeywords __Pyx_PyCFunctionFastWithKeywords
+#else
+ #define __Pyx_METH_FASTCALL METH_VARARGS
+ #define __Pyx_PyCFunction_FastCall PyCFunction
+ #define __Pyx_PyCFunction_FastCallWithKeywords PyCFunctionWithKeywords
+#endif
+#if CYTHON_VECTORCALL
+ #define __pyx_vectorcallfunc vectorcallfunc
+ #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET PY_VECTORCALL_ARGUMENTS_OFFSET
+ #define __Pyx_PyVectorcall_NARGS(n) PyVectorcall_NARGS((size_t)(n))
+#elif CYTHON_BACKPORT_VECTORCALL
+ typedef PyObject *(*__pyx_vectorcallfunc)(PyObject *callable, PyObject *const *args,
+ size_t nargsf, PyObject *kwnames);
+ #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET ((size_t)1 << (8 * sizeof(size_t) - 1))
+ #define __Pyx_PyVectorcall_NARGS(n) ((Py_ssize_t)(((size_t)(n)) & ~__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET))
+#else
+ #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET 0
+ #define __Pyx_PyVectorcall_NARGS(n) ((Py_ssize_t)(n))
+#endif
+#if PY_MAJOR_VERSION >= 0x030900B1
+#define __Pyx_PyCFunction_CheckExact(func) PyCFunction_CheckExact(func)
+#else
+#define __Pyx_PyCFunction_CheckExact(func) PyCFunction_Check(func)
+#endif
+#define __Pyx_CyOrPyCFunction_Check(func) PyCFunction_Check(func)
+#if CYTHON_COMPILING_IN_CPYTHON
+#define __Pyx_CyOrPyCFunction_GET_FUNCTION(func) (((PyCFunctionObject*)(func))->m_ml->ml_meth)
+#elif !CYTHON_COMPILING_IN_LIMITED_API
+#define __Pyx_CyOrPyCFunction_GET_FUNCTION(func) PyCFunction_GET_FUNCTION(func)
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON
+#define __Pyx_CyOrPyCFunction_GET_FLAGS(func) (((PyCFunctionObject*)(func))->m_ml->ml_flags)
+static CYTHON_INLINE PyObject* __Pyx_CyOrPyCFunction_GET_SELF(PyObject *func) {
+ return (__Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_STATIC) ? NULL : ((PyCFunctionObject*)func)->m_self;
+}
+#endif
+static CYTHON_INLINE int __Pyx__IsSameCFunction(PyObject *func, void *cfunc) {
+#if CYTHON_COMPILING_IN_LIMITED_API
+ return PyCFunction_Check(func) && PyCFunction_GetFunction(func) == (PyCFunction) cfunc;
+#else
+ return PyCFunction_Check(func) && PyCFunction_GET_FUNCTION(func) == (PyCFunction) cfunc;
+#endif
+}
+#define __Pyx_IsSameCFunction(func, cfunc) __Pyx__IsSameCFunction(func, cfunc)
+#if __PYX_LIMITED_VERSION_HEX < 0x030900B1
+ #define __Pyx_PyType_FromModuleAndSpec(m, s, b) ((void)m, PyType_FromSpecWithBases(s, b))
+ typedef PyObject *(*__Pyx_PyCMethod)(PyObject *, PyTypeObject *, PyObject *const *, size_t, PyObject *);
+#else
+ #define __Pyx_PyType_FromModuleAndSpec(m, s, b) PyType_FromModuleAndSpec(m, s, b)
+ #define __Pyx_PyCMethod PyCMethod
+#endif
+#ifndef METH_METHOD
+ #define METH_METHOD 0x200
+#endif
+#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Malloc)
+ #define PyObject_Malloc(s) PyMem_Malloc(s)
+ #define PyObject_Free(p) PyMem_Free(p)
+ #define PyObject_Realloc(p) PyMem_Realloc(p)
+#endif
+#if CYTHON_COMPILING_IN_LIMITED_API
+ #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0)
+ #define __Pyx_PyFrame_SetLineNumber(frame, lineno)
+#else
+ #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0)
+ #define __Pyx_PyFrame_SetLineNumber(frame, lineno) (frame)->f_lineno = (lineno)
+#endif
+#if CYTHON_COMPILING_IN_LIMITED_API
+ #define __Pyx_PyThreadState_Current PyThreadState_Get()
+#elif !CYTHON_FAST_THREAD_STATE
+ #define __Pyx_PyThreadState_Current PyThreadState_GET()
+#elif PY_VERSION_HEX >= 0x030d00A1
+ #define __Pyx_PyThreadState_Current PyThreadState_GetUnchecked()
+#elif PY_VERSION_HEX >= 0x03060000
+ #define __Pyx_PyThreadState_Current _PyThreadState_UncheckedGet()
+#elif PY_VERSION_HEX >= 0x03000000
+ #define __Pyx_PyThreadState_Current PyThreadState_GET()
+#else
+ #define __Pyx_PyThreadState_Current _PyThreadState_Current
+#endif
+#if CYTHON_COMPILING_IN_LIMITED_API
+static CYTHON_INLINE void *__Pyx_PyModule_GetState(PyObject *op)
+{
+ void *result;
+ result = PyModule_GetState(op);
+ if (!result)
+ Py_FatalError("Couldn't find the module state");
+ return result;
+}
+#endif
+#define __Pyx_PyObject_GetSlot(obj, name, func_ctype) __Pyx_PyType_GetSlot(Py_TYPE(obj), name, func_ctype)
+#if CYTHON_COMPILING_IN_LIMITED_API
+ #define __Pyx_PyType_GetSlot(type, name, func_ctype) ((func_ctype) PyType_GetSlot((type), Py_##name))
+#else
+ #define __Pyx_PyType_GetSlot(type, name, func_ctype) ((type)->name)
+#endif
+#if PY_VERSION_HEX < 0x030700A2 && !defined(PyThread_tss_create) && !defined(Py_tss_NEEDS_INIT)
+#include "pythread.h"
+#define Py_tss_NEEDS_INIT 0
+typedef int Py_tss_t;
+static CYTHON_INLINE int PyThread_tss_create(Py_tss_t *key) {
+ *key = PyThread_create_key();
+ return 0;
+}
+static CYTHON_INLINE Py_tss_t * PyThread_tss_alloc(void) {
+ Py_tss_t *key = (Py_tss_t *)PyObject_Malloc(sizeof(Py_tss_t));
+ *key = Py_tss_NEEDS_INIT;
+ return key;
+}
+static CYTHON_INLINE void PyThread_tss_free(Py_tss_t *key) {
+ PyObject_Free(key);
+}
+static CYTHON_INLINE int PyThread_tss_is_created(Py_tss_t *key) {
+ return *key != Py_tss_NEEDS_INIT;
+}
+static CYTHON_INLINE void PyThread_tss_delete(Py_tss_t *key) {
+ PyThread_delete_key(*key);
+ *key = Py_tss_NEEDS_INIT;
+}
+static CYTHON_INLINE int PyThread_tss_set(Py_tss_t *key, void *value) {
+ return PyThread_set_key_value(*key, value);
+}
+static CYTHON_INLINE void * PyThread_tss_get(Py_tss_t *key) {
+ return PyThread_get_key_value(*key);
+}
+#endif
+#if PY_MAJOR_VERSION < 3
+ #if CYTHON_COMPILING_IN_PYPY
+ #if PYPY_VERSION_NUM < 0x07030600
+ #if defined(__cplusplus) && __cplusplus >= 201402L
+ [[deprecated("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6")]]
+ #elif defined(__GNUC__) || defined(__clang__)
+ __attribute__ ((__deprecated__("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6")))
+ #elif defined(_MSC_VER)
+ __declspec(deprecated("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6"))
+ #endif
+ static CYTHON_INLINE int PyGILState_Check(void) {
+ return 0;
+ }
+ #else // PYPY_VERSION_NUM < 0x07030600
+ #endif // PYPY_VERSION_NUM < 0x07030600
+ #else
+ static CYTHON_INLINE int PyGILState_Check(void) {
+ PyThreadState * tstate = _PyThreadState_Current;
+ return tstate && (tstate == PyGILState_GetThisThreadState());
+ }
+ #endif
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030d0000 || defined(_PyDict_NewPresized)
+#define __Pyx_PyDict_NewPresized(n) ((n <= 8) ? PyDict_New() : _PyDict_NewPresized(n))
+#else
+#define __Pyx_PyDict_NewPresized(n) PyDict_New()
+#endif
+#if PY_MAJOR_VERSION >= 3 || CYTHON_FUTURE_DIVISION
+ #define __Pyx_PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y)
+ #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceTrueDivide(x,y)
+#else
+ #define __Pyx_PyNumber_Divide(x,y) PyNumber_Divide(x,y)
+ #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceDivide(x,y)
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX > 0x030600B4 && PY_VERSION_HEX < 0x030d0000 && CYTHON_USE_UNICODE_INTERNALS
+#define __Pyx_PyDict_GetItemStrWithError(dict, name) _PyDict_GetItem_KnownHash(dict, name, ((PyASCIIObject *) name)->hash)
+static CYTHON_INLINE PyObject * __Pyx_PyDict_GetItemStr(PyObject *dict, PyObject *name) {
+ PyObject *res = __Pyx_PyDict_GetItemStrWithError(dict, name);
+ if (res == NULL) PyErr_Clear();
+ return res;
+}
+#elif PY_MAJOR_VERSION >= 3 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07020000)
+#define __Pyx_PyDict_GetItemStrWithError PyDict_GetItemWithError
+#define __Pyx_PyDict_GetItemStr PyDict_GetItem
+#else
+static CYTHON_INLINE PyObject * __Pyx_PyDict_GetItemStrWithError(PyObject *dict, PyObject *name) {
+#if CYTHON_COMPILING_IN_PYPY
+ return PyDict_GetItem(dict, name);
+#else
+ PyDictEntry *ep;
+ PyDictObject *mp = (PyDictObject*) dict;
+ long hash = ((PyStringObject *) name)->ob_shash;
+ assert(hash != -1);
+ ep = (mp->ma_lookup)(mp, name, hash);
+ if (ep == NULL) {
+ return NULL;
+ }
+ return ep->me_value;
+#endif
+}
+#define __Pyx_PyDict_GetItemStr PyDict_GetItem
+#endif
+#if CYTHON_USE_TYPE_SLOTS
+ #define __Pyx_PyType_GetFlags(tp) (((PyTypeObject *)tp)->tp_flags)
+ #define __Pyx_PyType_HasFeature(type, feature) ((__Pyx_PyType_GetFlags(type) & (feature)) != 0)
+ #define __Pyx_PyObject_GetIterNextFunc(obj) (Py_TYPE(obj)->tp_iternext)
+#else
+ #define __Pyx_PyType_GetFlags(tp) (PyType_GetFlags((PyTypeObject *)tp))
+ #define __Pyx_PyType_HasFeature(type, feature) PyType_HasFeature(type, feature)
+ #define __Pyx_PyObject_GetIterNextFunc(obj) PyIter_Next
+#endif
+#if CYTHON_COMPILING_IN_LIMITED_API
+ #define __Pyx_SetItemOnTypeDict(tp, k, v) PyObject_GenericSetAttr((PyObject*)tp, k, v)
+#else
+ #define __Pyx_SetItemOnTypeDict(tp, k, v) PyDict_SetItem(tp->tp_dict, k, v)
+#endif
+#if CYTHON_USE_TYPE_SPECS && PY_VERSION_HEX >= 0x03080000
+#define __Pyx_PyHeapTypeObject_GC_Del(obj) {\
+ PyTypeObject *type = Py_TYPE((PyObject*)obj);\
+ assert(__Pyx_PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE));\
+ PyObject_GC_Del(obj);\
+ Py_DECREF(type);\
+}
+#else
+#define __Pyx_PyHeapTypeObject_GC_Del(obj) PyObject_GC_Del(obj)
+#endif
+#if CYTHON_COMPILING_IN_LIMITED_API
+ #define CYTHON_PEP393_ENABLED 1
+ #define __Pyx_PyUnicode_READY(op) (0)
+ #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GetLength(u)
+ #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_ReadChar(u, i)
+ #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((void)u, 1114111U)
+ #define __Pyx_PyUnicode_KIND(u) ((void)u, (0))
+ #define __Pyx_PyUnicode_DATA(u) ((void*)u)
+ #define __Pyx_PyUnicode_READ(k, d, i) ((void)k, PyUnicode_ReadChar((PyObject*)(d), i))
+ #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GetLength(u))
+#elif PY_VERSION_HEX > 0x03030000 && defined(PyUnicode_KIND)
+ #define CYTHON_PEP393_ENABLED 1
+ #if PY_VERSION_HEX >= 0x030C0000
+ #define __Pyx_PyUnicode_READY(op) (0)
+ #else
+ #define __Pyx_PyUnicode_READY(op) (likely(PyUnicode_IS_READY(op)) ?\
+ 0 : _PyUnicode_Ready((PyObject *)(op)))
+ #endif
+ #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_LENGTH(u)
+ #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_READ_CHAR(u, i)
+ #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) PyUnicode_MAX_CHAR_VALUE(u)
+ #define __Pyx_PyUnicode_KIND(u) ((int)PyUnicode_KIND(u))
+ #define __Pyx_PyUnicode_DATA(u) PyUnicode_DATA(u)
+ #define __Pyx_PyUnicode_READ(k, d, i) PyUnicode_READ(k, d, i)
+ #define __Pyx_PyUnicode_WRITE(k, d, i, ch) PyUnicode_WRITE(k, d, i, (Py_UCS4) ch)
+ #if PY_VERSION_HEX >= 0x030C0000
+ #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_LENGTH(u))
+ #else
+ #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x03090000
+ #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : ((PyCompactUnicodeObject *)(u))->wstr_length))
+ #else
+ #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : PyUnicode_GET_SIZE(u)))
+ #endif
+ #endif
+#else
+ #define CYTHON_PEP393_ENABLED 0
+ #define PyUnicode_1BYTE_KIND 1
+ #define PyUnicode_2BYTE_KIND 2
+ #define PyUnicode_4BYTE_KIND 4
+ #define __Pyx_PyUnicode_READY(op) (0)
+ #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_SIZE(u)
+ #define __Pyx_PyUnicode_READ_CHAR(u, i) ((Py_UCS4)(PyUnicode_AS_UNICODE(u)[i]))
+ #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((sizeof(Py_UNICODE) == 2) ? 65535U : 1114111U)
+ #define __Pyx_PyUnicode_KIND(u) ((int)sizeof(Py_UNICODE))
+ #define __Pyx_PyUnicode_DATA(u) ((void*)PyUnicode_AS_UNICODE(u))
+ #define __Pyx_PyUnicode_READ(k, d, i) ((void)(k), (Py_UCS4)(((Py_UNICODE*)d)[i]))
+ #define __Pyx_PyUnicode_WRITE(k, d, i, ch) (((void)(k)), ((Py_UNICODE*)d)[i] = (Py_UNICODE) ch)
+ #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_SIZE(u))
+#endif
+#if CYTHON_COMPILING_IN_PYPY
+ #define __Pyx_PyUnicode_Concat(a, b) PyNumber_Add(a, b)
+ #define __Pyx_PyUnicode_ConcatSafe(a, b) PyNumber_Add(a, b)
+#else
+ #define __Pyx_PyUnicode_Concat(a, b) PyUnicode_Concat(a, b)
+ #define __Pyx_PyUnicode_ConcatSafe(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ?\
+ PyNumber_Add(a, b) : __Pyx_PyUnicode_Concat(a, b))
+#endif
+#if CYTHON_COMPILING_IN_PYPY
+ #if !defined(PyUnicode_DecodeUnicodeEscape)
+ #define PyUnicode_DecodeUnicodeEscape(s, size, errors) PyUnicode_Decode(s, size, "unicode_escape", errors)
+ #endif
+ #if !defined(PyUnicode_Contains) || (PY_MAJOR_VERSION == 2 && PYPY_VERSION_NUM < 0x07030500)
+ #undef PyUnicode_Contains
+ #define PyUnicode_Contains(u, s) PySequence_Contains(u, s)
+ #endif
+ #if !defined(PyByteArray_Check)
+ #define PyByteArray_Check(obj) PyObject_TypeCheck(obj, &PyByteArray_Type)
+ #endif
+ #if !defined(PyObject_Format)
+ #define PyObject_Format(obj, fmt) PyObject_CallMethod(obj, "__format__", "O", fmt)
+ #endif
+#endif
+#define __Pyx_PyString_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyString_Check(b) && !PyString_CheckExact(b)))) ? PyNumber_Remainder(a, b) : __Pyx_PyString_Format(a, b))
+#define __Pyx_PyUnicode_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyUnicode_Check(b) && !PyUnicode_CheckExact(b)))) ? PyNumber_Remainder(a, b) : PyUnicode_Format(a, b))
+#if PY_MAJOR_VERSION >= 3
+ #define __Pyx_PyString_Format(a, b) PyUnicode_Format(a, b)
+#else
+ #define __Pyx_PyString_Format(a, b) PyString_Format(a, b)
+#endif
+#if PY_MAJOR_VERSION < 3 && !defined(PyObject_ASCII)
+ #define PyObject_ASCII(o) PyObject_Repr(o)
+#endif
+#if PY_MAJOR_VERSION >= 3
+ #define PyBaseString_Type PyUnicode_Type
+ #define PyStringObject PyUnicodeObject
+ #define PyString_Type PyUnicode_Type
+ #define PyString_Check PyUnicode_Check
+ #define PyString_CheckExact PyUnicode_CheckExact
+#ifndef PyObject_Unicode
+ #define PyObject_Unicode PyObject_Str
+#endif
+#endif
+#if PY_MAJOR_VERSION >= 3
+ #define __Pyx_PyBaseString_Check(obj) PyUnicode_Check(obj)
+ #define __Pyx_PyBaseString_CheckExact(obj) PyUnicode_CheckExact(obj)
+#else
+ #define __Pyx_PyBaseString_Check(obj) (PyString_Check(obj) || PyUnicode_Check(obj))
+ #define __Pyx_PyBaseString_CheckExact(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj))
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON
+ #define __Pyx_PySequence_ListKeepNew(obj)\
+ (likely(PyList_CheckExact(obj) && Py_REFCNT(obj) == 1) ? __Pyx_NewRef(obj) : PySequence_List(obj))
+#else
+ #define __Pyx_PySequence_ListKeepNew(obj) PySequence_List(obj)
+#endif
+#ifndef PySet_CheckExact
+ #define PySet_CheckExact(obj) __Pyx_IS_TYPE(obj, &PySet_Type)
+#endif
+#if PY_VERSION_HEX >= 0x030900A4
+ #define __Pyx_SET_REFCNT(obj, refcnt) Py_SET_REFCNT(obj, refcnt)
+ #define __Pyx_SET_SIZE(obj, size) Py_SET_SIZE(obj, size)
+#else
+ #define __Pyx_SET_REFCNT(obj, refcnt) Py_REFCNT(obj) = (refcnt)
+ #define __Pyx_SET_SIZE(obj, size) Py_SIZE(obj) = (size)
+#endif
+#if CYTHON_ASSUME_SAFE_MACROS
+ #define __Pyx_PySequence_ITEM(o, i) PySequence_ITEM(o, i)
+ #define __Pyx_PySequence_SIZE(seq) Py_SIZE(seq)
+ #define __Pyx_PyTuple_SET_ITEM(o, i, v) (PyTuple_SET_ITEM(o, i, v), (0))
+ #define __Pyx_PyList_SET_ITEM(o, i, v) (PyList_SET_ITEM(o, i, v), (0))
+ #define __Pyx_PyTuple_GET_SIZE(o) PyTuple_GET_SIZE(o)
+ #define __Pyx_PyList_GET_SIZE(o) PyList_GET_SIZE(o)
+ #define __Pyx_PySet_GET_SIZE(o) PySet_GET_SIZE(o)
+ #define __Pyx_PyBytes_GET_SIZE(o) PyBytes_GET_SIZE(o)
+ #define __Pyx_PyByteArray_GET_SIZE(o) PyByteArray_GET_SIZE(o)
+#else
+ #define __Pyx_PySequence_ITEM(o, i) PySequence_GetItem(o, i)
+ #define __Pyx_PySequence_SIZE(seq) PySequence_Size(seq)
+ #define __Pyx_PyTuple_SET_ITEM(o, i, v) PyTuple_SetItem(o, i, v)
+ #define __Pyx_PyList_SET_ITEM(o, i, v) PyList_SetItem(o, i, v)
+ #define __Pyx_PyTuple_GET_SIZE(o) PyTuple_Size(o)
+ #define __Pyx_PyList_GET_SIZE(o) PyList_Size(o)
+ #define __Pyx_PySet_GET_SIZE(o) PySet_Size(o)
+ #define __Pyx_PyBytes_GET_SIZE(o) PyBytes_Size(o)
+ #define __Pyx_PyByteArray_GET_SIZE(o) PyByteArray_Size(o)
+#endif
+#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1
+ #define __Pyx_PyImport_AddModuleRef(name) PyImport_AddModuleRef(name)
+#else
+ static CYTHON_INLINE PyObject *__Pyx_PyImport_AddModuleRef(const char *name) {
+ PyObject *module = PyImport_AddModule(name);
+ Py_XINCREF(module);
+ return module;
+ }
+#endif
+#if PY_MAJOR_VERSION >= 3
+ #define PyIntObject PyLongObject
+ #define PyInt_Type PyLong_Type
+ #define PyInt_Check(op) PyLong_Check(op)
+ #define PyInt_CheckExact(op) PyLong_CheckExact(op)
+ #define __Pyx_Py3Int_Check(op) PyLong_Check(op)
+ #define __Pyx_Py3Int_CheckExact(op) PyLong_CheckExact(op)
+ #define PyInt_FromString PyLong_FromString
+ #define PyInt_FromUnicode PyLong_FromUnicode
+ #define PyInt_FromLong PyLong_FromLong
+ #define PyInt_FromSize_t PyLong_FromSize_t
+ #define PyInt_FromSsize_t PyLong_FromSsize_t
+ #define PyInt_AsLong PyLong_AsLong
+ #define PyInt_AS_LONG PyLong_AS_LONG
+ #define PyInt_AsSsize_t PyLong_AsSsize_t
+ #define PyInt_AsUnsignedLongMask PyLong_AsUnsignedLongMask
+ #define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask
+ #define PyNumber_Int PyNumber_Long
+#else
+ #define __Pyx_Py3Int_Check(op) (PyLong_Check(op) || PyInt_Check(op))
+ #define __Pyx_Py3Int_CheckExact(op) (PyLong_CheckExact(op) || PyInt_CheckExact(op))
+#endif
+#if PY_MAJOR_VERSION >= 3
+ #define PyBoolObject PyLongObject
+#endif
+#if PY_MAJOR_VERSION >= 3 && CYTHON_COMPILING_IN_PYPY
+ #ifndef PyUnicode_InternFromString
+ #define PyUnicode_InternFromString(s) PyUnicode_FromString(s)
+ #endif
+#endif
+#if PY_VERSION_HEX < 0x030200A4
+ typedef long Py_hash_t;
+ #define __Pyx_PyInt_FromHash_t PyInt_FromLong
+ #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsHash_t
+#else
+ #define __Pyx_PyInt_FromHash_t PyInt_FromSsize_t
+ #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsSsize_t
+#endif
+#if CYTHON_USE_ASYNC_SLOTS
+ #if PY_VERSION_HEX >= 0x030500B1
+ #define __Pyx_PyAsyncMethodsStruct PyAsyncMethods
+ #define __Pyx_PyType_AsAsync(obj) (Py_TYPE(obj)->tp_as_async)
+ #else
+ #define __Pyx_PyType_AsAsync(obj) ((__Pyx_PyAsyncMethodsStruct*) (Py_TYPE(obj)->tp_reserved))
+ #endif
+#else
+ #define __Pyx_PyType_AsAsync(obj) NULL
+#endif
+#ifndef __Pyx_PyAsyncMethodsStruct
+ typedef struct {
+ unaryfunc am_await;
+ unaryfunc am_aiter;
+ unaryfunc am_anext;
+ } __Pyx_PyAsyncMethodsStruct;
+#endif
+
+#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS)
+ #if !defined(_USE_MATH_DEFINES)
+ #define _USE_MATH_DEFINES
+ #endif
+#endif
+#include
+#ifdef NAN
+#define __PYX_NAN() ((float) NAN)
+#else
+static CYTHON_INLINE float __PYX_NAN() {
+ float value;
+ memset(&value, 0xFF, sizeof(value));
+ return value;
+}
+#endif
+#if defined(__CYGWIN__) && defined(_LDBL_EQ_DBL)
+#define __Pyx_truncl trunc
+#else
+#define __Pyx_truncl truncl
+#endif
+
+#define __PYX_MARK_ERR_POS(f_index, lineno) \
+ { __pyx_filename = __pyx_f[f_index]; (void)__pyx_filename; __pyx_lineno = lineno; (void)__pyx_lineno; __pyx_clineno = __LINE__; (void)__pyx_clineno; }
+#define __PYX_ERR(f_index, lineno, Ln_error) \
+ { __PYX_MARK_ERR_POS(f_index, lineno) goto Ln_error; }
+
+#ifdef CYTHON_EXTERN_C
+ #undef __PYX_EXTERN_C
+ #define __PYX_EXTERN_C CYTHON_EXTERN_C
+#elif defined(__PYX_EXTERN_C)
+ #ifdef _MSC_VER
+ #pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.")
+ #else
+ #warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.
+ #endif
+#else
+ #ifdef __cplusplus
+ #define __PYX_EXTERN_C extern "C"
+ #else
+ #define __PYX_EXTERN_C extern
+ #endif
+#endif
+
+#define __PYX_HAVE__nms__cpu_nms
+#define __PYX_HAVE_API__nms__cpu_nms
+/* Early includes */
+#include
+#include
+
+ /* Using NumPy API declarations from "numpy/__init__.cython-30.pxd" */
+
+#include "numpy/arrayobject.h"
+#include "numpy/ndarrayobject.h"
+#include "numpy/ndarraytypes.h"
+#include "numpy/arrayscalars.h"
+#include "numpy/ufuncobject.h"
+#ifdef _OPENMP
+#include
+#endif /* _OPENMP */
+
+#if defined(PYREX_WITHOUT_ASSERTIONS) && !defined(CYTHON_WITHOUT_ASSERTIONS)
+#define CYTHON_WITHOUT_ASSERTIONS
+#endif
+
+typedef struct {PyObject **p; const char *s; const Py_ssize_t n; const char* encoding;
+ const char is_unicode; const char is_str; const char intern; } __Pyx_StringTabEntry;
+
+#define __PYX_DEFAULT_STRING_ENCODING_IS_ASCII 0
+#define __PYX_DEFAULT_STRING_ENCODING_IS_UTF8 0
+#define __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT (PY_MAJOR_VERSION >= 3 && __PYX_DEFAULT_STRING_ENCODING_IS_UTF8)
+#define __PYX_DEFAULT_STRING_ENCODING ""
+#define __Pyx_PyObject_FromString __Pyx_PyBytes_FromString
+#define __Pyx_PyObject_FromStringAndSize __Pyx_PyBytes_FromStringAndSize
+#define __Pyx_uchar_cast(c) ((unsigned char)c)
+#define __Pyx_long_cast(x) ((long)x)
+#define __Pyx_fits_Py_ssize_t(v, type, is_signed) (\
+ (sizeof(type) < sizeof(Py_ssize_t)) ||\
+ (sizeof(type) > sizeof(Py_ssize_t) &&\
+ likely(v < (type)PY_SSIZE_T_MAX ||\
+ v == (type)PY_SSIZE_T_MAX) &&\
+ (!is_signed || likely(v > (type)PY_SSIZE_T_MIN ||\
+ v == (type)PY_SSIZE_T_MIN))) ||\
+ (sizeof(type) == sizeof(Py_ssize_t) &&\
+ (is_signed || likely(v < (type)PY_SSIZE_T_MAX ||\
+ v == (type)PY_SSIZE_T_MAX))) )
+static CYTHON_INLINE int __Pyx_is_valid_index(Py_ssize_t i, Py_ssize_t limit) {
+ return (size_t) i < (size_t) limit;
+}
+#if defined (__cplusplus) && __cplusplus >= 201103L
+ #include
+ #define __Pyx_sst_abs(value) std::abs(value)
+#elif SIZEOF_INT >= SIZEOF_SIZE_T
+ #define __Pyx_sst_abs(value) abs(value)
+#elif SIZEOF_LONG >= SIZEOF_SIZE_T
+ #define __Pyx_sst_abs(value) labs(value)
+#elif defined (_MSC_VER)
+ #define __Pyx_sst_abs(value) ((Py_ssize_t)_abs64(value))
+#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
+ #define __Pyx_sst_abs(value) llabs(value)
+#elif defined (__GNUC__)
+ #define __Pyx_sst_abs(value) __builtin_llabs(value)
+#else
+ #define __Pyx_sst_abs(value) ((value<0) ? -value : value)
+#endif
+static CYTHON_INLINE Py_ssize_t __Pyx_ssize_strlen(const char *s);
+static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject*);
+static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject*, Py_ssize_t* length);
+static CYTHON_INLINE PyObject* __Pyx_PyByteArray_FromString(const char*);
+#define __Pyx_PyByteArray_FromStringAndSize(s, l) PyByteArray_FromStringAndSize((const char*)s, l)
+#define __Pyx_PyBytes_FromString PyBytes_FromString
+#define __Pyx_PyBytes_FromStringAndSize PyBytes_FromStringAndSize
+static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char*);
+#if PY_MAJOR_VERSION < 3
+ #define __Pyx_PyStr_FromString __Pyx_PyBytes_FromString
+ #define __Pyx_PyStr_FromStringAndSize __Pyx_PyBytes_FromStringAndSize
+#else
+ #define __Pyx_PyStr_FromString __Pyx_PyUnicode_FromString
+ #define __Pyx_PyStr_FromStringAndSize __Pyx_PyUnicode_FromStringAndSize
+#endif
+#define __Pyx_PyBytes_AsWritableString(s) ((char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsWritableSString(s) ((signed char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsWritableUString(s) ((unsigned char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsString(s) ((const char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsSString(s) ((const signed char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyBytes_AsUString(s) ((const unsigned char*) PyBytes_AS_STRING(s))
+#define __Pyx_PyObject_AsWritableString(s) ((char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_AsWritableSString(s) ((signed char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_AsWritableUString(s) ((unsigned char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_AsSString(s) ((const signed char*) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_AsUString(s) ((const unsigned char*) __Pyx_PyObject_AsString(s))
+#define __Pyx_PyObject_FromCString(s) __Pyx_PyObject_FromString((const char*)s)
+#define __Pyx_PyBytes_FromCString(s) __Pyx_PyBytes_FromString((const char*)s)
+#define __Pyx_PyByteArray_FromCString(s) __Pyx_PyByteArray_FromString((const char*)s)
+#define __Pyx_PyStr_FromCString(s) __Pyx_PyStr_FromString((const char*)s)
+#define __Pyx_PyUnicode_FromCString(s) __Pyx_PyUnicode_FromString((const char*)s)
+#define __Pyx_PyUnicode_FromOrdinal(o) PyUnicode_FromOrdinal((int)o)
+#define __Pyx_PyUnicode_AsUnicode PyUnicode_AsUnicode
+#define __Pyx_NewRef(obj) (Py_INCREF(obj), obj)
+#define __Pyx_Owned_Py_None(b) __Pyx_NewRef(Py_None)
+static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b);
+static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject*);
+static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject*);
+static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x);
+#define __Pyx_PySequence_Tuple(obj)\
+ (likely(PyTuple_CheckExact(obj)) ? __Pyx_NewRef(obj) : PySequence_Tuple(obj))
+static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject*);
+static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t);
+static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject*);
+#if CYTHON_ASSUME_SAFE_MACROS
+#define __pyx_PyFloat_AsDouble(x) (PyFloat_CheckExact(x) ? PyFloat_AS_DOUBLE(x) : PyFloat_AsDouble(x))
+#else
+#define __pyx_PyFloat_AsDouble(x) PyFloat_AsDouble(x)
+#endif
+#define __pyx_PyFloat_AsFloat(x) ((float) __pyx_PyFloat_AsDouble(x))
+#if PY_MAJOR_VERSION >= 3
+#define __Pyx_PyNumber_Int(x) (PyLong_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Long(x))
+#else
+#define __Pyx_PyNumber_Int(x) (PyInt_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Int(x))
+#endif
+#if CYTHON_USE_PYLONG_INTERNALS
+ #if PY_VERSION_HEX >= 0x030C00A7
+ #ifndef _PyLong_SIGN_MASK
+ #define _PyLong_SIGN_MASK 3
+ #endif
+ #ifndef _PyLong_NON_SIZE_BITS
+ #define _PyLong_NON_SIZE_BITS 3
+ #endif
+ #define __Pyx_PyLong_Sign(x) (((PyLongObject*)x)->long_value.lv_tag & _PyLong_SIGN_MASK)
+ #define __Pyx_PyLong_IsNeg(x) ((__Pyx_PyLong_Sign(x) & 2) != 0)
+ #define __Pyx_PyLong_IsNonNeg(x) (!__Pyx_PyLong_IsNeg(x))
+ #define __Pyx_PyLong_IsZero(x) (__Pyx_PyLong_Sign(x) & 1)
+ #define __Pyx_PyLong_IsPos(x) (__Pyx_PyLong_Sign(x) == 0)
+ #define __Pyx_PyLong_CompactValueUnsigned(x) (__Pyx_PyLong_Digits(x)[0])
+ #define __Pyx_PyLong_DigitCount(x) ((Py_ssize_t) (((PyLongObject*)x)->long_value.lv_tag >> _PyLong_NON_SIZE_BITS))
+ #define __Pyx_PyLong_SignedDigitCount(x)\
+ ((1 - (Py_ssize_t) __Pyx_PyLong_Sign(x)) * __Pyx_PyLong_DigitCount(x))
+ #if defined(PyUnstable_Long_IsCompact) && defined(PyUnstable_Long_CompactValue)
+ #define __Pyx_PyLong_IsCompact(x) PyUnstable_Long_IsCompact((PyLongObject*) x)
+ #define __Pyx_PyLong_CompactValue(x) PyUnstable_Long_CompactValue((PyLongObject*) x)
+ #else
+ #define __Pyx_PyLong_IsCompact(x) (((PyLongObject*)x)->long_value.lv_tag < (2 << _PyLong_NON_SIZE_BITS))
+ #define __Pyx_PyLong_CompactValue(x) ((1 - (Py_ssize_t) __Pyx_PyLong_Sign(x)) * (Py_ssize_t) __Pyx_PyLong_Digits(x)[0])
+ #endif
+ typedef Py_ssize_t __Pyx_compact_pylong;
+ typedef size_t __Pyx_compact_upylong;
+ #else
+ #define __Pyx_PyLong_IsNeg(x) (Py_SIZE(x) < 0)
+ #define __Pyx_PyLong_IsNonNeg(x) (Py_SIZE(x) >= 0)
+ #define __Pyx_PyLong_IsZero(x) (Py_SIZE(x) == 0)
+ #define __Pyx_PyLong_IsPos(x) (Py_SIZE(x) > 0)
+ #define __Pyx_PyLong_CompactValueUnsigned(x) ((Py_SIZE(x) == 0) ? 0 : __Pyx_PyLong_Digits(x)[0])
+ #define __Pyx_PyLong_DigitCount(x) __Pyx_sst_abs(Py_SIZE(x))
+ #define __Pyx_PyLong_SignedDigitCount(x) Py_SIZE(x)
+ #define __Pyx_PyLong_IsCompact(x) (Py_SIZE(x) == 0 || Py_SIZE(x) == 1 || Py_SIZE(x) == -1)
+ #define __Pyx_PyLong_CompactValue(x)\
+ ((Py_SIZE(x) == 0) ? (sdigit) 0 : ((Py_SIZE(x) < 0) ? -(sdigit)__Pyx_PyLong_Digits(x)[0] : (sdigit)__Pyx_PyLong_Digits(x)[0]))
+ typedef sdigit __Pyx_compact_pylong;
+ typedef digit __Pyx_compact_upylong;
+ #endif
+ #if PY_VERSION_HEX >= 0x030C00A5
+ #define __Pyx_PyLong_Digits(x) (((PyLongObject*)x)->long_value.ob_digit)
+ #else
+ #define __Pyx_PyLong_Digits(x) (((PyLongObject*)x)->ob_digit)
+ #endif
+#endif
+#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII
+#include
+static int __Pyx_sys_getdefaultencoding_not_ascii;
+static int __Pyx_init_sys_getdefaultencoding_params(void) {
+ PyObject* sys;
+ PyObject* default_encoding = NULL;
+ PyObject* ascii_chars_u = NULL;
+ PyObject* ascii_chars_b = NULL;
+ const char* default_encoding_c;
+ sys = PyImport_ImportModule("sys");
+ if (!sys) goto bad;
+ default_encoding = PyObject_CallMethod(sys, (char*) "getdefaultencoding", NULL);
+ Py_DECREF(sys);
+ if (!default_encoding) goto bad;
+ default_encoding_c = PyBytes_AsString(default_encoding);
+ if (!default_encoding_c) goto bad;
+ if (strcmp(default_encoding_c, "ascii") == 0) {
+ __Pyx_sys_getdefaultencoding_not_ascii = 0;
+ } else {
+ char ascii_chars[128];
+ int c;
+ for (c = 0; c < 128; c++) {
+ ascii_chars[c] = (char) c;
+ }
+ __Pyx_sys_getdefaultencoding_not_ascii = 1;
+ ascii_chars_u = PyUnicode_DecodeASCII(ascii_chars, 128, NULL);
+ if (!ascii_chars_u) goto bad;
+ ascii_chars_b = PyUnicode_AsEncodedString(ascii_chars_u, default_encoding_c, NULL);
+ if (!ascii_chars_b || !PyBytes_Check(ascii_chars_b) || memcmp(ascii_chars, PyBytes_AS_STRING(ascii_chars_b), 128) != 0) {
+ PyErr_Format(
+ PyExc_ValueError,
+ "This module compiled with c_string_encoding=ascii, but default encoding '%.200s' is not a superset of ascii.",
+ default_encoding_c);
+ goto bad;
+ }
+ Py_DECREF(ascii_chars_u);
+ Py_DECREF(ascii_chars_b);
+ }
+ Py_DECREF(default_encoding);
+ return 0;
+bad:
+ Py_XDECREF(default_encoding);
+ Py_XDECREF(ascii_chars_u);
+ Py_XDECREF(ascii_chars_b);
+ return -1;
+}
+#endif
+#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT && PY_MAJOR_VERSION >= 3
+#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_DecodeUTF8(c_str, size, NULL)
+#else
+#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_Decode(c_str, size, __PYX_DEFAULT_STRING_ENCODING, NULL)
+#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT
+#include
+static char* __PYX_DEFAULT_STRING_ENCODING;
+static int __Pyx_init_sys_getdefaultencoding_params(void) {
+ PyObject* sys;
+ PyObject* default_encoding = NULL;
+ char* default_encoding_c;
+ sys = PyImport_ImportModule("sys");
+ if (!sys) goto bad;
+ default_encoding = PyObject_CallMethod(sys, (char*) (const char*) "getdefaultencoding", NULL);
+ Py_DECREF(sys);
+ if (!default_encoding) goto bad;
+ default_encoding_c = PyBytes_AsString(default_encoding);
+ if (!default_encoding_c) goto bad;
+ __PYX_DEFAULT_STRING_ENCODING = (char*) malloc(strlen(default_encoding_c) + 1);
+ if (!__PYX_DEFAULT_STRING_ENCODING) goto bad;
+ strcpy(__PYX_DEFAULT_STRING_ENCODING, default_encoding_c);
+ Py_DECREF(default_encoding);
+ return 0;
+bad:
+ Py_XDECREF(default_encoding);
+ return -1;
+}
+#endif
+#endif
+
+
+/* Test for GCC > 2.95 */
+#if defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && (__GNUC_MINOR__ > 95)))
+ #define likely(x) __builtin_expect(!!(x), 1)
+ #define unlikely(x) __builtin_expect(!!(x), 0)
+#else /* !__GNUC__ or GCC < 2.95 */
+ #define likely(x) (x)
+ #define unlikely(x) (x)
+#endif /* __GNUC__ */
+static CYTHON_INLINE void __Pyx_pretend_to_initialize(void* ptr) { (void)ptr; }
+
+#if !CYTHON_USE_MODULE_STATE
+static PyObject *__pyx_m = NULL;
+#endif
+static int __pyx_lineno;
+static int __pyx_clineno = 0;
+static const char * __pyx_cfilenm = __FILE__;
+static const char *__pyx_filename;
+
+/* Header.proto */
+#if !defined(CYTHON_CCOMPLEX)
+ #if defined(__cplusplus)
+ #define CYTHON_CCOMPLEX 1
+ #elif (defined(_Complex_I) && !defined(_MSC_VER)) || ((defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) && !defined(__STDC_NO_COMPLEX__) && !defined(_MSC_VER))
+ #define CYTHON_CCOMPLEX 1
+ #else
+ #define CYTHON_CCOMPLEX 0
+ #endif
+#endif
+#if CYTHON_CCOMPLEX
+ #ifdef __cplusplus
+ #include
+ #else
+ #include
+ #endif
+#endif
+#if CYTHON_CCOMPLEX && !defined(__cplusplus) && defined(__sun__) && defined(__GNUC__)
+ #undef _Complex_I
+ #define _Complex_I 1.0fj
+#endif
+
+/* #### Code section: filename_table ### */
+
+static const char *__pyx_f[] = {
+ "nms/cpu_nms.pyx",
+ "__init__.cython-30.pxd",
+ "type.pxd",
+};
+/* #### Code section: utility_code_proto_before_types ### */
+/* ForceInitThreads.proto */
+#ifndef __PYX_FORCE_INIT_THREADS
+ #define __PYX_FORCE_INIT_THREADS 0
+#endif
+
+/* BufferFormatStructs.proto */
+struct __Pyx_StructField_;
+#define __PYX_BUF_FLAGS_PACKED_STRUCT (1 << 0)
+typedef struct {
+ const char* name;
+ struct __Pyx_StructField_* fields;
+ size_t size;
+ size_t arraysize[8];
+ int ndim;
+ char typegroup;
+ char is_unsigned;
+ int flags;
+} __Pyx_TypeInfo;
+typedef struct __Pyx_StructField_ {
+ __Pyx_TypeInfo* type;
+ const char* name;
+ size_t offset;
+} __Pyx_StructField;
+typedef struct {
+ __Pyx_StructField* field;
+ size_t parent_offset;
+} __Pyx_BufFmt_StackElem;
+typedef struct {
+ __Pyx_StructField root;
+ __Pyx_BufFmt_StackElem* head;
+ size_t fmt_offset;
+ size_t new_count, enc_count;
+ size_t struct_alignment;
+ int is_complex;
+ char enc_type;
+ char new_packmode;
+ char enc_packmode;
+ char is_valid_array;
+} __Pyx_BufFmt_Context;
+
+/* #### Code section: numeric_typedefs ### */
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":731
+ * # in Cython to enable them only on the right systems.
+ *
+ * ctypedef npy_int8 int8_t # <<<<<<<<<<<<<<
+ * ctypedef npy_int16 int16_t
+ * ctypedef npy_int32 int32_t
+ */
+typedef npy_int8 __pyx_t_5numpy_int8_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":732
+ *
+ * ctypedef npy_int8 int8_t
+ * ctypedef npy_int16 int16_t # <<<<<<<<<<<<<<
+ * ctypedef npy_int32 int32_t
+ * ctypedef npy_int64 int64_t
+ */
+typedef npy_int16 __pyx_t_5numpy_int16_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":733
+ * ctypedef npy_int8 int8_t
+ * ctypedef npy_int16 int16_t
+ * ctypedef npy_int32 int32_t # <<<<<<<<<<<<<<
+ * ctypedef npy_int64 int64_t
+ * #ctypedef npy_int96 int96_t
+ */
+typedef npy_int32 __pyx_t_5numpy_int32_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":734
+ * ctypedef npy_int16 int16_t
+ * ctypedef npy_int32 int32_t
+ * ctypedef npy_int64 int64_t # <<<<<<<<<<<<<<
+ * #ctypedef npy_int96 int96_t
+ * #ctypedef npy_int128 int128_t
+ */
+typedef npy_int64 __pyx_t_5numpy_int64_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":738
+ * #ctypedef npy_int128 int128_t
+ *
+ * ctypedef npy_uint8 uint8_t # <<<<<<<<<<<<<<
+ * ctypedef npy_uint16 uint16_t
+ * ctypedef npy_uint32 uint32_t
+ */
+typedef npy_uint8 __pyx_t_5numpy_uint8_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":739
+ *
+ * ctypedef npy_uint8 uint8_t
+ * ctypedef npy_uint16 uint16_t # <<<<<<<<<<<<<<
+ * ctypedef npy_uint32 uint32_t
+ * ctypedef npy_uint64 uint64_t
+ */
+typedef npy_uint16 __pyx_t_5numpy_uint16_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":740
+ * ctypedef npy_uint8 uint8_t
+ * ctypedef npy_uint16 uint16_t
+ * ctypedef npy_uint32 uint32_t # <<<<<<<<<<<<<<
+ * ctypedef npy_uint64 uint64_t
+ * #ctypedef npy_uint96 uint96_t
+ */
+typedef npy_uint32 __pyx_t_5numpy_uint32_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":741
+ * ctypedef npy_uint16 uint16_t
+ * ctypedef npy_uint32 uint32_t
+ * ctypedef npy_uint64 uint64_t # <<<<<<<<<<<<<<
+ * #ctypedef npy_uint96 uint96_t
+ * #ctypedef npy_uint128 uint128_t
+ */
+typedef npy_uint64 __pyx_t_5numpy_uint64_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":745
+ * #ctypedef npy_uint128 uint128_t
+ *
+ * ctypedef npy_float32 float32_t # <<<<<<<<<<<<<<
+ * ctypedef npy_float64 float64_t
+ * #ctypedef npy_float80 float80_t
+ */
+typedef npy_float32 __pyx_t_5numpy_float32_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":746
+ *
+ * ctypedef npy_float32 float32_t
+ * ctypedef npy_float64 float64_t # <<<<<<<<<<<<<<
+ * #ctypedef npy_float80 float80_t
+ * #ctypedef npy_float128 float128_t
+ */
+typedef npy_float64 __pyx_t_5numpy_float64_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":755
+ * # The int types are mapped a bit surprising --
+ * # numpy.int corresponds to 'l' and numpy.long to 'q'
+ * ctypedef npy_long int_t # <<<<<<<<<<<<<<
+ * ctypedef npy_longlong long_t
+ * ctypedef npy_longlong longlong_t
+ */
+typedef npy_long __pyx_t_5numpy_int_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":756
+ * # numpy.int corresponds to 'l' and numpy.long to 'q'
+ * ctypedef npy_long int_t
+ * ctypedef npy_longlong long_t # <<<<<<<<<<<<<<
+ * ctypedef npy_longlong longlong_t
+ *
+ */
+typedef npy_longlong __pyx_t_5numpy_long_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":757
+ * ctypedef npy_long int_t
+ * ctypedef npy_longlong long_t
+ * ctypedef npy_longlong longlong_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_ulong uint_t
+ */
+typedef npy_longlong __pyx_t_5numpy_longlong_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":759
+ * ctypedef npy_longlong longlong_t
+ *
+ * ctypedef npy_ulong uint_t # <<<<<<<<<<<<<<
+ * ctypedef npy_ulonglong ulong_t
+ * ctypedef npy_ulonglong ulonglong_t
+ */
+typedef npy_ulong __pyx_t_5numpy_uint_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":760
+ *
+ * ctypedef npy_ulong uint_t
+ * ctypedef npy_ulonglong ulong_t # <<<<<<<<<<<<<<
+ * ctypedef npy_ulonglong ulonglong_t
+ *
+ */
+typedef npy_ulonglong __pyx_t_5numpy_ulong_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":761
+ * ctypedef npy_ulong uint_t
+ * ctypedef npy_ulonglong ulong_t
+ * ctypedef npy_ulonglong ulonglong_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_intp intp_t
+ */
+typedef npy_ulonglong __pyx_t_5numpy_ulonglong_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":763
+ * ctypedef npy_ulonglong ulonglong_t
+ *
+ * ctypedef npy_intp intp_t # <<<<<<<<<<<<<<
+ * ctypedef npy_uintp uintp_t
+ *
+ */
+typedef npy_intp __pyx_t_5numpy_intp_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":764
+ *
+ * ctypedef npy_intp intp_t
+ * ctypedef npy_uintp uintp_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_double float_t
+ */
+typedef npy_uintp __pyx_t_5numpy_uintp_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":766
+ * ctypedef npy_uintp uintp_t
+ *
+ * ctypedef npy_double float_t # <<<<<<<<<<<<<<
+ * ctypedef npy_double double_t
+ * ctypedef npy_longdouble longdouble_t
+ */
+typedef npy_double __pyx_t_5numpy_float_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":767
+ *
+ * ctypedef npy_double float_t
+ * ctypedef npy_double double_t # <<<<<<<<<<<<<<
+ * ctypedef npy_longdouble longdouble_t
+ *
+ */
+typedef npy_double __pyx_t_5numpy_double_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":768
+ * ctypedef npy_double float_t
+ * ctypedef npy_double double_t
+ * ctypedef npy_longdouble longdouble_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_cfloat cfloat_t
+ */
+typedef npy_longdouble __pyx_t_5numpy_longdouble_t;
+/* #### Code section: complex_type_declarations ### */
+/* Declarations.proto */
+#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus)
+ #ifdef __cplusplus
+ typedef ::std::complex< float > __pyx_t_float_complex;
+ #else
+ typedef float _Complex __pyx_t_float_complex;
+ #endif
+#else
+ typedef struct { float real, imag; } __pyx_t_float_complex;
+#endif
+static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float, float);
+
+/* Declarations.proto */
+#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus)
+ #ifdef __cplusplus
+ typedef ::std::complex< double > __pyx_t_double_complex;
+ #else
+ typedef double _Complex __pyx_t_double_complex;
+ #endif
+#else
+ typedef struct { double real, imag; } __pyx_t_double_complex;
+#endif
+static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double, double);
+
+/* #### Code section: type_declarations ### */
+
+/*--- Type declarations ---*/
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":770
+ * ctypedef npy_longdouble longdouble_t
+ *
+ * ctypedef npy_cfloat cfloat_t # <<<<<<<<<<<<<<
+ * ctypedef npy_cdouble cdouble_t
+ * ctypedef npy_clongdouble clongdouble_t
+ */
+typedef npy_cfloat __pyx_t_5numpy_cfloat_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":771
+ *
+ * ctypedef npy_cfloat cfloat_t
+ * ctypedef npy_cdouble cdouble_t # <<<<<<<<<<<<<<
+ * ctypedef npy_clongdouble clongdouble_t
+ *
+ */
+typedef npy_cdouble __pyx_t_5numpy_cdouble_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":772
+ * ctypedef npy_cfloat cfloat_t
+ * ctypedef npy_cdouble cdouble_t
+ * ctypedef npy_clongdouble clongdouble_t # <<<<<<<<<<<<<<
+ *
+ * ctypedef npy_cdouble complex_t
+ */
+typedef npy_clongdouble __pyx_t_5numpy_clongdouble_t;
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":774
+ * ctypedef npy_clongdouble clongdouble_t
+ *
+ * ctypedef npy_cdouble complex_t # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew1(a):
+ */
+typedef npy_cdouble __pyx_t_5numpy_complex_t;
+/* #### Code section: utility_code_proto ### */
+
+/* --- Runtime support code (head) --- */
+/* Refnanny.proto */
+#ifndef CYTHON_REFNANNY
+ #define CYTHON_REFNANNY 0
+#endif
+#if CYTHON_REFNANNY
+ typedef struct {
+ void (*INCREF)(void*, PyObject*, Py_ssize_t);
+ void (*DECREF)(void*, PyObject*, Py_ssize_t);
+ void (*GOTREF)(void*, PyObject*, Py_ssize_t);
+ void (*GIVEREF)(void*, PyObject*, Py_ssize_t);
+ void* (*SetupContext)(const char*, Py_ssize_t, const char*);
+ void (*FinishContext)(void**);
+ } __Pyx_RefNannyAPIStruct;
+ static __Pyx_RefNannyAPIStruct *__Pyx_RefNanny = NULL;
+ static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname);
+ #define __Pyx_RefNannyDeclarations void *__pyx_refnanny = NULL;
+#ifdef WITH_THREAD
+ #define __Pyx_RefNannySetupContext(name, acquire_gil)\
+ if (acquire_gil) {\
+ PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\
+ __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__));\
+ PyGILState_Release(__pyx_gilstate_save);\
+ } else {\
+ __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__));\
+ }
+ #define __Pyx_RefNannyFinishContextNogil() {\
+ PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\
+ __Pyx_RefNannyFinishContext();\
+ PyGILState_Release(__pyx_gilstate_save);\
+ }
+#else
+ #define __Pyx_RefNannySetupContext(name, acquire_gil)\
+ __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__))
+ #define __Pyx_RefNannyFinishContextNogil() __Pyx_RefNannyFinishContext()
+#endif
+ #define __Pyx_RefNannyFinishContextNogil() {\
+ PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\
+ __Pyx_RefNannyFinishContext();\
+ PyGILState_Release(__pyx_gilstate_save);\
+ }
+ #define __Pyx_RefNannyFinishContext()\
+ __Pyx_RefNanny->FinishContext(&__pyx_refnanny)
+ #define __Pyx_INCREF(r) __Pyx_RefNanny->INCREF(__pyx_refnanny, (PyObject *)(r), (__LINE__))
+ #define __Pyx_DECREF(r) __Pyx_RefNanny->DECREF(__pyx_refnanny, (PyObject *)(r), (__LINE__))
+ #define __Pyx_GOTREF(r) __Pyx_RefNanny->GOTREF(__pyx_refnanny, (PyObject *)(r), (__LINE__))
+ #define __Pyx_GIVEREF(r) __Pyx_RefNanny->GIVEREF(__pyx_refnanny, (PyObject *)(r), (__LINE__))
+ #define __Pyx_XINCREF(r) do { if((r) == NULL); else {__Pyx_INCREF(r); }} while(0)
+ #define __Pyx_XDECREF(r) do { if((r) == NULL); else {__Pyx_DECREF(r); }} while(0)
+ #define __Pyx_XGOTREF(r) do { if((r) == NULL); else {__Pyx_GOTREF(r); }} while(0)
+ #define __Pyx_XGIVEREF(r) do { if((r) == NULL); else {__Pyx_GIVEREF(r);}} while(0)
+#else
+ #define __Pyx_RefNannyDeclarations
+ #define __Pyx_RefNannySetupContext(name, acquire_gil)
+ #define __Pyx_RefNannyFinishContextNogil()
+ #define __Pyx_RefNannyFinishContext()
+ #define __Pyx_INCREF(r) Py_INCREF(r)
+ #define __Pyx_DECREF(r) Py_DECREF(r)
+ #define __Pyx_GOTREF(r)
+ #define __Pyx_GIVEREF(r)
+ #define __Pyx_XINCREF(r) Py_XINCREF(r)
+ #define __Pyx_XDECREF(r) Py_XDECREF(r)
+ #define __Pyx_XGOTREF(r)
+ #define __Pyx_XGIVEREF(r)
+#endif
+#define __Pyx_Py_XDECREF_SET(r, v) do {\
+ PyObject *tmp = (PyObject *) r;\
+ r = v; Py_XDECREF(tmp);\
+ } while (0)
+#define __Pyx_XDECREF_SET(r, v) do {\
+ PyObject *tmp = (PyObject *) r;\
+ r = v; __Pyx_XDECREF(tmp);\
+ } while (0)
+#define __Pyx_DECREF_SET(r, v) do {\
+ PyObject *tmp = (PyObject *) r;\
+ r = v; __Pyx_DECREF(tmp);\
+ } while (0)
+#define __Pyx_CLEAR(r) do { PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);} while(0)
+#define __Pyx_XCLEAR(r) do { if((r) != NULL) {PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);}} while(0)
+
+/* PyErrExceptionMatches.proto */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_PyErr_ExceptionMatches(err) __Pyx_PyErr_ExceptionMatchesInState(__pyx_tstate, err)
+static CYTHON_INLINE int __Pyx_PyErr_ExceptionMatchesInState(PyThreadState* tstate, PyObject* err);
+#else
+#define __Pyx_PyErr_ExceptionMatches(err) PyErr_ExceptionMatches(err)
+#endif
+
+/* PyThreadStateGet.proto */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_PyThreadState_declare PyThreadState *__pyx_tstate;
+#define __Pyx_PyThreadState_assign __pyx_tstate = __Pyx_PyThreadState_Current;
+#if PY_VERSION_HEX >= 0x030C00A6
+#define __Pyx_PyErr_Occurred() (__pyx_tstate->current_exception != NULL)
+#define __Pyx_PyErr_CurrentExceptionType() (__pyx_tstate->current_exception ? (PyObject*) Py_TYPE(__pyx_tstate->current_exception) : (PyObject*) NULL)
+#else
+#define __Pyx_PyErr_Occurred() (__pyx_tstate->curexc_type != NULL)
+#define __Pyx_PyErr_CurrentExceptionType() (__pyx_tstate->curexc_type)
+#endif
+#else
+#define __Pyx_PyThreadState_declare
+#define __Pyx_PyThreadState_assign
+#define __Pyx_PyErr_Occurred() (PyErr_Occurred() != NULL)
+#define __Pyx_PyErr_CurrentExceptionType() PyErr_Occurred()
+#endif
+
+/* PyErrFetchRestore.proto */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_PyErr_Clear() __Pyx_ErrRestore(NULL, NULL, NULL)
+#define __Pyx_ErrRestoreWithState(type, value, tb) __Pyx_ErrRestoreInState(PyThreadState_GET(), type, value, tb)
+#define __Pyx_ErrFetchWithState(type, value, tb) __Pyx_ErrFetchInState(PyThreadState_GET(), type, value, tb)
+#define __Pyx_ErrRestore(type, value, tb) __Pyx_ErrRestoreInState(__pyx_tstate, type, value, tb)
+#define __Pyx_ErrFetch(type, value, tb) __Pyx_ErrFetchInState(__pyx_tstate, type, value, tb)
+static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb);
+static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb);
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A6
+#define __Pyx_PyErr_SetNone(exc) (Py_INCREF(exc), __Pyx_ErrRestore((exc), NULL, NULL))
+#else
+#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc)
+#endif
+#else
+#define __Pyx_PyErr_Clear() PyErr_Clear()
+#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc)
+#define __Pyx_ErrRestoreWithState(type, value, tb) PyErr_Restore(type, value, tb)
+#define __Pyx_ErrFetchWithState(type, value, tb) PyErr_Fetch(type, value, tb)
+#define __Pyx_ErrRestoreInState(tstate, type, value, tb) PyErr_Restore(type, value, tb)
+#define __Pyx_ErrFetchInState(tstate, type, value, tb) PyErr_Fetch(type, value, tb)
+#define __Pyx_ErrRestore(type, value, tb) PyErr_Restore(type, value, tb)
+#define __Pyx_ErrFetch(type, value, tb) PyErr_Fetch(type, value, tb)
+#endif
+
+/* PyObjectGetAttrStr.proto */
+#if CYTHON_USE_TYPE_SLOTS
+static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name);
+#else
+#define __Pyx_PyObject_GetAttrStr(o,n) PyObject_GetAttr(o,n)
+#endif
+
+/* PyObjectGetAttrStrNoError.proto */
+static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name);
+
+/* GetBuiltinName.proto */
+static PyObject *__Pyx_GetBuiltinName(PyObject *name);
+
+/* GetTopmostException.proto */
+#if CYTHON_USE_EXC_INFO_STACK && CYTHON_FAST_THREAD_STATE
+static _PyErr_StackItem * __Pyx_PyErr_GetTopmostException(PyThreadState *tstate);
+#endif
+
+/* SaveResetException.proto */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_ExceptionSave(type, value, tb) __Pyx__ExceptionSave(__pyx_tstate, type, value, tb)
+static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb);
+#define __Pyx_ExceptionReset(type, value, tb) __Pyx__ExceptionReset(__pyx_tstate, type, value, tb)
+static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb);
+#else
+#define __Pyx_ExceptionSave(type, value, tb) PyErr_GetExcInfo(type, value, tb)
+#define __Pyx_ExceptionReset(type, value, tb) PyErr_SetExcInfo(type, value, tb)
+#endif
+
+/* GetException.proto */
+#if CYTHON_FAST_THREAD_STATE
+#define __Pyx_GetException(type, value, tb) __Pyx__GetException(__pyx_tstate, type, value, tb)
+static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb);
+#else
+static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb);
+#endif
+
+/* PyObjectCall.proto */
+#if CYTHON_COMPILING_IN_CPYTHON
+static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw);
+#else
+#define __Pyx_PyObject_Call(func, arg, kw) PyObject_Call(func, arg, kw)
+#endif
+
+/* RaiseException.proto */
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause);
+
+/* TupleAndListFromArray.proto */
+#if CYTHON_COMPILING_IN_CPYTHON
+static CYTHON_INLINE PyObject* __Pyx_PyList_FromArray(PyObject *const *src, Py_ssize_t n);
+static CYTHON_INLINE PyObject* __Pyx_PyTuple_FromArray(PyObject *const *src, Py_ssize_t n);
+#endif
+
+/* IncludeStringH.proto */
+#include
+
+/* BytesEquals.proto */
+static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals);
+
+/* UnicodeEquals.proto */
+static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals);
+
+/* fastcall.proto */
+#if CYTHON_AVOID_BORROWED_REFS
+ #define __Pyx_Arg_VARARGS(args, i) PySequence_GetItem(args, i)
+#elif CYTHON_ASSUME_SAFE_MACROS
+ #define __Pyx_Arg_VARARGS(args, i) PyTuple_GET_ITEM(args, i)
+#else
+ #define __Pyx_Arg_VARARGS(args, i) PyTuple_GetItem(args, i)
+#endif
+#if CYTHON_AVOID_BORROWED_REFS
+ #define __Pyx_Arg_NewRef_VARARGS(arg) __Pyx_NewRef(arg)
+ #define __Pyx_Arg_XDECREF_VARARGS(arg) Py_XDECREF(arg)
+#else
+ #define __Pyx_Arg_NewRef_VARARGS(arg) arg
+ #define __Pyx_Arg_XDECREF_VARARGS(arg)
+#endif
+#define __Pyx_NumKwargs_VARARGS(kwds) PyDict_Size(kwds)
+#define __Pyx_KwValues_VARARGS(args, nargs) NULL
+#define __Pyx_GetKwValue_VARARGS(kw, kwvalues, s) __Pyx_PyDict_GetItemStrWithError(kw, s)
+#define __Pyx_KwargsAsDict_VARARGS(kw, kwvalues) PyDict_Copy(kw)
+#if CYTHON_METH_FASTCALL
+ #define __Pyx_Arg_FASTCALL(args, i) args[i]
+ #define __Pyx_NumKwargs_FASTCALL(kwds) PyTuple_GET_SIZE(kwds)
+ #define __Pyx_KwValues_FASTCALL(args, nargs) ((args) + (nargs))
+ static CYTHON_INLINE PyObject * __Pyx_GetKwValue_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues, PyObject *s);
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000
+ CYTHON_UNUSED static PyObject *__Pyx_KwargsAsDict_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues);
+ #else
+ #define __Pyx_KwargsAsDict_FASTCALL(kw, kwvalues) _PyStack_AsDict(kwvalues, kw)
+ #endif
+ #define __Pyx_Arg_NewRef_FASTCALL(arg) arg /* no-op, __Pyx_Arg_FASTCALL is direct and this needs
+ to have the same reference counting */
+ #define __Pyx_Arg_XDECREF_FASTCALL(arg)
+#else
+ #define __Pyx_Arg_FASTCALL __Pyx_Arg_VARARGS
+ #define __Pyx_NumKwargs_FASTCALL __Pyx_NumKwargs_VARARGS
+ #define __Pyx_KwValues_FASTCALL __Pyx_KwValues_VARARGS
+ #define __Pyx_GetKwValue_FASTCALL __Pyx_GetKwValue_VARARGS
+ #define __Pyx_KwargsAsDict_FASTCALL __Pyx_KwargsAsDict_VARARGS
+ #define __Pyx_Arg_NewRef_FASTCALL(arg) __Pyx_Arg_NewRef_VARARGS(arg)
+ #define __Pyx_Arg_XDECREF_FASTCALL(arg) __Pyx_Arg_XDECREF_VARARGS(arg)
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
+#define __Pyx_ArgsSlice_VARARGS(args, start, stop) __Pyx_PyTuple_FromArray(&__Pyx_Arg_VARARGS(args, start), stop - start)
+#define __Pyx_ArgsSlice_FASTCALL(args, start, stop) __Pyx_PyTuple_FromArray(&__Pyx_Arg_FASTCALL(args, start), stop - start)
+#else
+#define __Pyx_ArgsSlice_VARARGS(args, start, stop) PyTuple_GetSlice(args, start, stop)
+#define __Pyx_ArgsSlice_FASTCALL(args, start, stop) PyTuple_GetSlice(args, start, stop)
+#endif
+
+/* RaiseArgTupleInvalid.proto */
+static void __Pyx_RaiseArgtupleInvalid(const char* func_name, int exact,
+ Py_ssize_t num_min, Py_ssize_t num_max, Py_ssize_t num_found);
+
+/* RaiseDoubleKeywords.proto */
+static void __Pyx_RaiseDoubleKeywordsError(const char* func_name, PyObject* kw_name);
+
+/* ParseKeywords.proto */
+static int __Pyx_ParseOptionalKeywords(PyObject *kwds, PyObject *const *kwvalues,
+ PyObject **argnames[],
+ PyObject *kwds2, PyObject *values[], Py_ssize_t num_pos_args,
+ const char* function_name);
+
+/* ArgTypeTest.proto */
+#define __Pyx_ArgTypeTest(obj, type, none_allowed, name, exact)\
+ ((likely(__Pyx_IS_TYPE(obj, type) | (none_allowed && (obj == Py_None)))) ? 1 :\
+ __Pyx__ArgTypeTest(obj, type, name, exact))
+static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact);
+
+/* IsLittleEndian.proto */
+static CYTHON_INLINE int __Pyx_Is_Little_Endian(void);
+
+/* BufferFormatCheck.proto */
+static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts);
+static void __Pyx_BufFmt_Init(__Pyx_BufFmt_Context* ctx,
+ __Pyx_BufFmt_StackElem* stack,
+ __Pyx_TypeInfo* type);
+
+/* BufferGetAndValidate.proto */
+#define __Pyx_GetBufferAndValidate(buf, obj, dtype, flags, nd, cast, stack)\
+ ((obj == Py_None || obj == NULL) ?\
+ (__Pyx_ZeroBuffer(buf), 0) :\
+ __Pyx__GetBufferAndValidate(buf, obj, dtype, flags, nd, cast, stack))
+static int __Pyx__GetBufferAndValidate(Py_buffer* buf, PyObject* obj,
+ __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack);
+static void __Pyx_ZeroBuffer(Py_buffer* buf);
+static CYTHON_INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
+static Py_ssize_t __Pyx_minusones[] = { -1, -1, -1, -1, -1, -1, -1, -1 };
+static Py_ssize_t __Pyx_zeros[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
+
+/* GetItemInt.proto */
+#define __Pyx_GetItemInt(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\
+ (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\
+ __Pyx_GetItemInt_Fast(o, (Py_ssize_t)i, is_list, wraparound, boundscheck) :\
+ (is_list ? (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL) :\
+ __Pyx_GetItemInt_Generic(o, to_py_func(i))))
+#define __Pyx_GetItemInt_List(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\
+ (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\
+ __Pyx_GetItemInt_List_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\
+ (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL))
+static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i,
+ int wraparound, int boundscheck);
+#define __Pyx_GetItemInt_Tuple(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\
+ (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\
+ __Pyx_GetItemInt_Tuple_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\
+ (PyErr_SetString(PyExc_IndexError, "tuple index out of range"), (PyObject*)NULL))
+static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i,
+ int wraparound, int boundscheck);
+static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j);
+static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i,
+ int is_list, int wraparound, int boundscheck);
+
+/* PyFunctionFastCall.proto */
+#if CYTHON_FAST_PYCALL
+#if !CYTHON_VECTORCALL
+#define __Pyx_PyFunction_FastCall(func, args, nargs)\
+ __Pyx_PyFunction_FastCallDict((func), (args), (nargs), NULL)
+static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs);
+#endif
+#define __Pyx_BUILD_ASSERT_EXPR(cond)\
+ (sizeof(char [1 - 2*!(cond)]) - 1)
+#ifndef Py_MEMBER_SIZE
+#define Py_MEMBER_SIZE(type, member) sizeof(((type *)0)->member)
+#endif
+#if !CYTHON_VECTORCALL
+#if PY_VERSION_HEX >= 0x03080000
+ #include "frameobject.h"
+#if PY_VERSION_HEX >= 0x030b00a6 && !CYTHON_COMPILING_IN_LIMITED_API && !defined(PYPY_VERSION)
+ #ifndef Py_BUILD_CORE
+ #define Py_BUILD_CORE 1
+ #endif
+ #include "internal/pycore_frame.h"
+#endif
+ #define __Pxy_PyFrame_Initialize_Offsets()
+ #define __Pyx_PyFrame_GetLocalsplus(frame) ((frame)->f_localsplus)
+#else
+ static size_t __pyx_pyframe_localsplus_offset = 0;
+ #include "frameobject.h"
+ #define __Pxy_PyFrame_Initialize_Offsets()\
+ ((void)__Pyx_BUILD_ASSERT_EXPR(sizeof(PyFrameObject) == offsetof(PyFrameObject, f_localsplus) + Py_MEMBER_SIZE(PyFrameObject, f_localsplus)),\
+ (void)(__pyx_pyframe_localsplus_offset = ((size_t)PyFrame_Type.tp_basicsize) - Py_MEMBER_SIZE(PyFrameObject, f_localsplus)))
+ #define __Pyx_PyFrame_GetLocalsplus(frame)\
+ (assert(__pyx_pyframe_localsplus_offset), (PyObject **)(((char *)(frame)) + __pyx_pyframe_localsplus_offset))
+#endif
+#endif
+#endif
+
+/* PyObjectCallMethO.proto */
+#if CYTHON_COMPILING_IN_CPYTHON
+static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg);
+#endif
+
+/* PyObjectFastCall.proto */
+#define __Pyx_PyObject_FastCall(func, args, nargs) __Pyx_PyObject_FastCallDict(func, args, (size_t)(nargs), NULL)
+static CYTHON_INLINE PyObject* __Pyx_PyObject_FastCallDict(PyObject *func, PyObject **args, size_t nargs, PyObject *kwargs);
+
+/* PyObjectCallOneArg.proto */
+static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg);
+
+/* ObjectGetItem.proto */
+#if CYTHON_USE_TYPE_SLOTS
+static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key);
+#else
+#define __Pyx_PyObject_GetItem(obj, key) PyObject_GetItem(obj, key)
+#endif
+
+/* ExtTypeTest.proto */
+static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type);
+
+/* PyIntBinop.proto */
+#if !CYTHON_COMPILING_IN_PYPY
+static PyObject* __Pyx_PyInt_AddObjC(PyObject *op1, PyObject *op2, long intval, int inplace, int zerodivision_check);
+#else
+#define __Pyx_PyInt_AddObjC(op1, op2, intval, inplace, zerodivision_check)\
+ (inplace ? PyNumber_InPlaceAdd(op1, op2) : PyNumber_Add(op1, op2))
+#endif
+
+/* PyDictVersioning.proto */
+#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS
+#define __PYX_DICT_VERSION_INIT ((PY_UINT64_T) -1)
+#define __PYX_GET_DICT_VERSION(dict) (((PyDictObject*)(dict))->ma_version_tag)
+#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var)\
+ (version_var) = __PYX_GET_DICT_VERSION(dict);\
+ (cache_var) = (value);
+#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) {\
+ static PY_UINT64_T __pyx_dict_version = 0;\
+ static PyObject *__pyx_dict_cached_value = NULL;\
+ if (likely(__PYX_GET_DICT_VERSION(DICT) == __pyx_dict_version)) {\
+ (VAR) = __pyx_dict_cached_value;\
+ } else {\
+ (VAR) = __pyx_dict_cached_value = (LOOKUP);\
+ __pyx_dict_version = __PYX_GET_DICT_VERSION(DICT);\
+ }\
+}
+static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj);
+static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj);
+static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version);
+#else
+#define __PYX_GET_DICT_VERSION(dict) (0)
+#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var)
+#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) (VAR) = (LOOKUP);
+#endif
+
+/* GetModuleGlobalName.proto */
+#if CYTHON_USE_DICT_VERSIONS
+#define __Pyx_GetModuleGlobalName(var, name) do {\
+ static PY_UINT64_T __pyx_dict_version = 0;\
+ static PyObject *__pyx_dict_cached_value = NULL;\
+ (var) = (likely(__pyx_dict_version == __PYX_GET_DICT_VERSION(__pyx_d))) ?\
+ (likely(__pyx_dict_cached_value) ? __Pyx_NewRef(__pyx_dict_cached_value) : __Pyx_GetBuiltinName(name)) :\
+ __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\
+} while(0)
+#define __Pyx_GetModuleGlobalNameUncached(var, name) do {\
+ PY_UINT64_T __pyx_dict_version;\
+ PyObject *__pyx_dict_cached_value;\
+ (var) = __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\
+} while(0)
+static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value);
+#else
+#define __Pyx_GetModuleGlobalName(var, name) (var) = __Pyx__GetModuleGlobalName(name)
+#define __Pyx_GetModuleGlobalNameUncached(var, name) (var) = __Pyx__GetModuleGlobalName(name)
+static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name);
+#endif
+
+/* BufferIndexError.proto */
+static void __Pyx_RaiseBufferIndexError(int axis);
+
+#define __Pyx_BufPtrStrided1d(type, buf, i0, s0) (type)((char*)buf + i0 * s0)
+/* ListAppend.proto */
+#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS
+static CYTHON_INLINE int __Pyx_PyList_Append(PyObject* list, PyObject* x) {
+ PyListObject* L = (PyListObject*) list;
+ Py_ssize_t len = Py_SIZE(list);
+ if (likely(L->allocated > len) & likely(len > (L->allocated >> 1))) {
+ Py_INCREF(x);
+ #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000
+ L->ob_item[len] = x;
+ #else
+ PyList_SET_ITEM(list, len, x);
+ #endif
+ __Pyx_SET_SIZE(list, len + 1);
+ return 0;
+ }
+ return PyList_Append(list, x);
+}
+#else
+#define __Pyx_PyList_Append(L,x) PyList_Append(L,x)
+#endif
+
+#define __Pyx_BufPtrStrided2d(type, buf, i0, s0, i1, s1) (type)((char*)buf + i0 * s0 + i1 * s1)
+/* ListCompAppend.proto */
+#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS
+static CYTHON_INLINE int __Pyx_ListComp_Append(PyObject* list, PyObject* x) {
+ PyListObject* L = (PyListObject*) list;
+ Py_ssize_t len = Py_SIZE(list);
+ if (likely(L->allocated > len)) {
+ Py_INCREF(x);
+ #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000
+ L->ob_item[len] = x;
+ #else
+ PyList_SET_ITEM(list, len, x);
+ #endif
+ __Pyx_SET_SIZE(list, len + 1);
+ return 0;
+ }
+ return PyList_Append(list, x);
+}
+#else
+#define __Pyx_ListComp_Append(L,x) PyList_Append(L,x)
+#endif
+
+/* TypeImport.proto */
+#ifndef __PYX_HAVE_RT_ImportType_proto_3_0_12
+#define __PYX_HAVE_RT_ImportType_proto_3_0_12
+#if defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
+#include
+#endif
+#if (defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) || __cplusplus >= 201103L
+#define __PYX_GET_STRUCT_ALIGNMENT_3_0_12(s) alignof(s)
+#else
+#define __PYX_GET_STRUCT_ALIGNMENT_3_0_12(s) sizeof(void*)
+#endif
+enum __Pyx_ImportType_CheckSize_3_0_12 {
+ __Pyx_ImportType_CheckSize_Error_3_0_12 = 0,
+ __Pyx_ImportType_CheckSize_Warn_3_0_12 = 1,
+ __Pyx_ImportType_CheckSize_Ignore_3_0_12 = 2
+};
+static PyTypeObject *__Pyx_ImportType_3_0_12(PyObject* module, const char *module_name, const char *class_name, size_t size, size_t alignment, enum __Pyx_ImportType_CheckSize_3_0_12 check_size);
+#endif
+
+/* Import.proto */
+static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level);
+
+/* ImportDottedModule.proto */
+static PyObject *__Pyx_ImportDottedModule(PyObject *name, PyObject *parts_tuple);
+#if PY_MAJOR_VERSION >= 3
+static PyObject *__Pyx_ImportDottedModule_WalkParts(PyObject *module, PyObject *name, PyObject *parts_tuple);
+#endif
+
+/* IncludeStructmemberH.proto */
+#include
+
+/* FixUpExtensionType.proto */
+#if CYTHON_USE_TYPE_SPECS
+static int __Pyx_fix_up_extension_type_from_spec(PyType_Spec *spec, PyTypeObject *type);
+#endif
+
+/* FetchSharedCythonModule.proto */
+static PyObject *__Pyx_FetchSharedCythonABIModule(void);
+
+/* FetchCommonType.proto */
+#if !CYTHON_USE_TYPE_SPECS
+static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type);
+#else
+static PyTypeObject* __Pyx_FetchCommonTypeFromSpec(PyObject *module, PyType_Spec *spec, PyObject *bases);
+#endif
+
+/* PyMethodNew.proto */
+#if CYTHON_COMPILING_IN_LIMITED_API
+static PyObject *__Pyx_PyMethod_New(PyObject *func, PyObject *self, PyObject *typ) {
+ PyObject *typesModule=NULL, *methodType=NULL, *result=NULL;
+ CYTHON_UNUSED_VAR(typ);
+ if (!self)
+ return __Pyx_NewRef(func);
+ typesModule = PyImport_ImportModule("types");
+ if (!typesModule) return NULL;
+ methodType = PyObject_GetAttrString(typesModule, "MethodType");
+ Py_DECREF(typesModule);
+ if (!methodType) return NULL;
+ result = PyObject_CallFunctionObjArgs(methodType, func, self, NULL);
+ Py_DECREF(methodType);
+ return result;
+}
+#elif PY_MAJOR_VERSION >= 3
+static PyObject *__Pyx_PyMethod_New(PyObject *func, PyObject *self, PyObject *typ) {
+ CYTHON_UNUSED_VAR(typ);
+ if (!self)
+ return __Pyx_NewRef(func);
+ return PyMethod_New(func, self);
+}
+#else
+ #define __Pyx_PyMethod_New PyMethod_New
+#endif
+
+/* PyVectorcallFastCallDict.proto */
+#if CYTHON_METH_FASTCALL
+static CYTHON_INLINE PyObject *__Pyx_PyVectorcall_FastCallDict(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw);
+#endif
+
+/* CythonFunctionShared.proto */
+#define __Pyx_CyFunction_USED
+#define __Pyx_CYFUNCTION_STATICMETHOD 0x01
+#define __Pyx_CYFUNCTION_CLASSMETHOD 0x02
+#define __Pyx_CYFUNCTION_CCLASS 0x04
+#define __Pyx_CYFUNCTION_COROUTINE 0x08
+#define __Pyx_CyFunction_GetClosure(f)\
+ (((__pyx_CyFunctionObject *) (f))->func_closure)
+#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API
+ #define __Pyx_CyFunction_GetClassObj(f)\
+ (((__pyx_CyFunctionObject *) (f))->func_classobj)
+#else
+ #define __Pyx_CyFunction_GetClassObj(f)\
+ ((PyObject*) ((PyCMethodObject *) (f))->mm_class)
+#endif
+#define __Pyx_CyFunction_SetClassObj(f, classobj)\
+ __Pyx__CyFunction_SetClassObj((__pyx_CyFunctionObject *) (f), (classobj))
+#define __Pyx_CyFunction_Defaults(type, f)\
+ ((type *)(((__pyx_CyFunctionObject *) (f))->defaults))
+#define __Pyx_CyFunction_SetDefaultsGetter(f, g)\
+ ((__pyx_CyFunctionObject *) (f))->defaults_getter = (g)
+typedef struct {
+#if CYTHON_COMPILING_IN_LIMITED_API
+ PyObject_HEAD
+ PyObject *func;
+#elif PY_VERSION_HEX < 0x030900B1
+ PyCFunctionObject func;
+#else
+ PyCMethodObject func;
+#endif
+#if CYTHON_BACKPORT_VECTORCALL
+ __pyx_vectorcallfunc func_vectorcall;
+#endif
+#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API
+ PyObject *func_weakreflist;
+#endif
+ PyObject *func_dict;
+ PyObject *func_name;
+ PyObject *func_qualname;
+ PyObject *func_doc;
+ PyObject *func_globals;
+ PyObject *func_code;
+ PyObject *func_closure;
+#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API
+ PyObject *func_classobj;
+#endif
+ void *defaults;
+ int defaults_pyobjects;
+ size_t defaults_size;
+ int flags;
+ PyObject *defaults_tuple;
+ PyObject *defaults_kwdict;
+ PyObject *(*defaults_getter)(PyObject *);
+ PyObject *func_annotations;
+ PyObject *func_is_coroutine;
+} __pyx_CyFunctionObject;
+#undef __Pyx_CyOrPyCFunction_Check
+#define __Pyx_CyFunction_Check(obj) __Pyx_TypeCheck(obj, __pyx_CyFunctionType)
+#define __Pyx_CyOrPyCFunction_Check(obj) __Pyx_TypeCheck2(obj, __pyx_CyFunctionType, &PyCFunction_Type)
+#define __Pyx_CyFunction_CheckExact(obj) __Pyx_IS_TYPE(obj, __pyx_CyFunctionType)
+static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc);
+#undef __Pyx_IsSameCFunction
+#define __Pyx_IsSameCFunction(func, cfunc) __Pyx__IsSameCyOrCFunction(func, cfunc)
+static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject* op, PyMethodDef *ml,
+ int flags, PyObject* qualname,
+ PyObject *closure,
+ PyObject *module, PyObject *globals,
+ PyObject* code);
+static CYTHON_INLINE void __Pyx__CyFunction_SetClassObj(__pyx_CyFunctionObject* f, PyObject* classobj);
+static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *m,
+ size_t size,
+ int pyobjects);
+static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *m,
+ PyObject *tuple);
+static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *m,
+ PyObject *dict);
+static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *m,
+ PyObject *dict);
+static int __pyx_CyFunction_init(PyObject *module);
+#if CYTHON_METH_FASTCALL
+static PyObject * __Pyx_CyFunction_Vectorcall_NOARGS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames);
+static PyObject * __Pyx_CyFunction_Vectorcall_O(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames);
+static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames);
+static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames);
+#if CYTHON_BACKPORT_VECTORCALL
+#define __Pyx_CyFunction_func_vectorcall(f) (((__pyx_CyFunctionObject*)f)->func_vectorcall)
+#else
+#define __Pyx_CyFunction_func_vectorcall(f) (((PyCFunctionObject*)f)->vectorcall)
+#endif
+#endif
+
+/* CythonFunction.proto */
+static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml,
+ int flags, PyObject* qualname,
+ PyObject *closure,
+ PyObject *module, PyObject *globals,
+ PyObject* code);
+
+/* CLineInTraceback.proto */
+#ifdef CYTHON_CLINE_IN_TRACEBACK
+#define __Pyx_CLineForTraceback(tstate, c_line) (((CYTHON_CLINE_IN_TRACEBACK)) ? c_line : 0)
+#else
+static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line);
+#endif
+
+/* CodeObjectCache.proto */
+#if !CYTHON_COMPILING_IN_LIMITED_API
+typedef struct {
+ PyCodeObject* code_object;
+ int code_line;
+} __Pyx_CodeObjectCacheEntry;
+struct __Pyx_CodeObjectCache {
+ int count;
+ int max_count;
+ __Pyx_CodeObjectCacheEntry* entries;
+};
+static struct __Pyx_CodeObjectCache __pyx_code_cache = {0,0,NULL};
+static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line);
+static PyCodeObject *__pyx_find_code_object(int code_line);
+static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object);
+#endif
+
+/* AddTraceback.proto */
+static void __Pyx_AddTraceback(const char *funcname, int c_line,
+ int py_line, const char *filename);
+
+/* BufferStructDeclare.proto */
+typedef struct {
+ Py_ssize_t shape, strides, suboffsets;
+} __Pyx_Buf_DimInfo;
+typedef struct {
+ size_t refcount;
+ Py_buffer pybuffer;
+} __Pyx_Buffer;
+typedef struct {
+ __Pyx_Buffer *rcbuffer;
+ char *data;
+ __Pyx_Buf_DimInfo diminfo[8];
+} __Pyx_LocalBuf_ND;
+
+#if PY_MAJOR_VERSION < 3
+ static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
+ static void __Pyx_ReleaseBuffer(Py_buffer *view);
+#else
+ #define __Pyx_GetBuffer PyObject_GetBuffer
+ #define __Pyx_ReleaseBuffer PyBuffer_Release
+#endif
+
+
+/* GCCDiagnostics.proto */
+#if !defined(__INTEL_COMPILER) && defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))
+#define __Pyx_HAS_GCC_DIAGNOSTIC
+#endif
+
+/* RealImag.proto */
+#if CYTHON_CCOMPLEX
+ #ifdef __cplusplus
+ #define __Pyx_CREAL(z) ((z).real())
+ #define __Pyx_CIMAG(z) ((z).imag())
+ #else
+ #define __Pyx_CREAL(z) (__real__(z))
+ #define __Pyx_CIMAG(z) (__imag__(z))
+ #endif
+#else
+ #define __Pyx_CREAL(z) ((z).real)
+ #define __Pyx_CIMAG(z) ((z).imag)
+#endif
+#if defined(__cplusplus) && CYTHON_CCOMPLEX\
+ && (defined(_WIN32) || defined(__clang__) || (defined(__GNUC__) && (__GNUC__ >= 5 || __GNUC__ == 4 && __GNUC_MINOR__ >= 4 )) || __cplusplus >= 201103)
+ #define __Pyx_SET_CREAL(z,x) ((z).real(x))
+ #define __Pyx_SET_CIMAG(z,y) ((z).imag(y))
+#else
+ #define __Pyx_SET_CREAL(z,x) __Pyx_CREAL(z) = (x)
+ #define __Pyx_SET_CIMAG(z,y) __Pyx_CIMAG(z) = (y)
+#endif
+
+/* Arithmetic.proto */
+#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus)
+ #define __Pyx_c_eq_float(a, b) ((a)==(b))
+ #define __Pyx_c_sum_float(a, b) ((a)+(b))
+ #define __Pyx_c_diff_float(a, b) ((a)-(b))
+ #define __Pyx_c_prod_float(a, b) ((a)*(b))
+ #define __Pyx_c_quot_float(a, b) ((a)/(b))
+ #define __Pyx_c_neg_float(a) (-(a))
+ #ifdef __cplusplus
+ #define __Pyx_c_is_zero_float(z) ((z)==(float)0)
+ #define __Pyx_c_conj_float(z) (::std::conj(z))
+ #if 1
+ #define __Pyx_c_abs_float(z) (::std::abs(z))
+ #define __Pyx_c_pow_float(a, b) (::std::pow(a, b))
+ #endif
+ #else
+ #define __Pyx_c_is_zero_float(z) ((z)==0)
+ #define __Pyx_c_conj_float(z) (conjf(z))
+ #if 1
+ #define __Pyx_c_abs_float(z) (cabsf(z))
+ #define __Pyx_c_pow_float(a, b) (cpowf(a, b))
+ #endif
+ #endif
+#else
+ static CYTHON_INLINE int __Pyx_c_eq_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_sum_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_diff_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_prod_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_neg_float(__pyx_t_float_complex);
+ static CYTHON_INLINE int __Pyx_c_is_zero_float(__pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_conj_float(__pyx_t_float_complex);
+ #if 1
+ static CYTHON_INLINE float __Pyx_c_abs_float(__pyx_t_float_complex);
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_pow_float(__pyx_t_float_complex, __pyx_t_float_complex);
+ #endif
+#endif
+
+/* Arithmetic.proto */
+#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus)
+ #define __Pyx_c_eq_double(a, b) ((a)==(b))
+ #define __Pyx_c_sum_double(a, b) ((a)+(b))
+ #define __Pyx_c_diff_double(a, b) ((a)-(b))
+ #define __Pyx_c_prod_double(a, b) ((a)*(b))
+ #define __Pyx_c_quot_double(a, b) ((a)/(b))
+ #define __Pyx_c_neg_double(a) (-(a))
+ #ifdef __cplusplus
+ #define __Pyx_c_is_zero_double(z) ((z)==(double)0)
+ #define __Pyx_c_conj_double(z) (::std::conj(z))
+ #if 1
+ #define __Pyx_c_abs_double(z) (::std::abs(z))
+ #define __Pyx_c_pow_double(a, b) (::std::pow(a, b))
+ #endif
+ #else
+ #define __Pyx_c_is_zero_double(z) ((z)==0)
+ #define __Pyx_c_conj_double(z) (conj(z))
+ #if 1
+ #define __Pyx_c_abs_double(z) (cabs(z))
+ #define __Pyx_c_pow_double(a, b) (cpow(a, b))
+ #endif
+ #endif
+#else
+ static CYTHON_INLINE int __Pyx_c_eq_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_sum_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_diff_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_prod_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_neg_double(__pyx_t_double_complex);
+ static CYTHON_INLINE int __Pyx_c_is_zero_double(__pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_conj_double(__pyx_t_double_complex);
+ #if 1
+ static CYTHON_INLINE double __Pyx_c_abs_double(__pyx_t_double_complex);
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_pow_double(__pyx_t_double_complex, __pyx_t_double_complex);
+ #endif
+#endif
+
+/* CIntFromPy.proto */
+static CYTHON_INLINE unsigned int __Pyx_PyInt_As_unsigned_int(PyObject *);
+
+/* CIntToPy.proto */
+static CYTHON_INLINE PyObject* __Pyx_PyInt_From_unsigned_int(unsigned int value);
+
+/* CIntToPy.proto */
+static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value);
+
+/* CIntFromPy.proto */
+static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *);
+
+/* CIntToPy.proto */
+static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value);
+
+/* FormatTypeName.proto */
+#if CYTHON_COMPILING_IN_LIMITED_API
+typedef PyObject *__Pyx_TypeName;
+#define __Pyx_FMT_TYPENAME "%U"
+static __Pyx_TypeName __Pyx_PyType_GetName(PyTypeObject* tp);
+#define __Pyx_DECREF_TypeName(obj) Py_XDECREF(obj)
+#else
+typedef const char *__Pyx_TypeName;
+#define __Pyx_FMT_TYPENAME "%.200s"
+#define __Pyx_PyType_GetName(tp) ((tp)->tp_name)
+#define __Pyx_DECREF_TypeName(obj)
+#endif
+
+/* CIntFromPy.proto */
+static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *);
+
+/* FastTypeChecks.proto */
+#if CYTHON_COMPILING_IN_CPYTHON
+#define __Pyx_TypeCheck(obj, type) __Pyx_IsSubtype(Py_TYPE(obj), (PyTypeObject *)type)
+#define __Pyx_TypeCheck2(obj, type1, type2) __Pyx_IsAnySubtype2(Py_TYPE(obj), (PyTypeObject *)type1, (PyTypeObject *)type2)
+static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b);
+static CYTHON_INLINE int __Pyx_IsAnySubtype2(PyTypeObject *cls, PyTypeObject *a, PyTypeObject *b);
+static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches(PyObject *err, PyObject *type);
+static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches2(PyObject *err, PyObject *type1, PyObject *type2);
+#else
+#define __Pyx_TypeCheck(obj, type) PyObject_TypeCheck(obj, (PyTypeObject *)type)
+#define __Pyx_TypeCheck2(obj, type1, type2) (PyObject_TypeCheck(obj, (PyTypeObject *)type1) || PyObject_TypeCheck(obj, (PyTypeObject *)type2))
+#define __Pyx_PyErr_GivenExceptionMatches(err, type) PyErr_GivenExceptionMatches(err, type)
+#define __Pyx_PyErr_GivenExceptionMatches2(err, type1, type2) (PyErr_GivenExceptionMatches(err, type1) || PyErr_GivenExceptionMatches(err, type2))
+#endif
+#define __Pyx_PyErr_ExceptionMatches2(err1, err2) __Pyx_PyErr_GivenExceptionMatches2(__Pyx_PyErr_CurrentExceptionType(), err1, err2)
+#define __Pyx_PyException_Check(obj) __Pyx_TypeCheck(obj, PyExc_Exception)
+
+/* CheckBinaryVersion.proto */
+static unsigned long __Pyx_get_runtime_version(void);
+static int __Pyx_check_binary_version(unsigned long ct_version, unsigned long rt_version, int allow_newer);
+
+/* InitStrings.proto */
+static int __Pyx_InitStrings(__Pyx_StringTabEntry *t);
+
+/* #### Code section: module_declarations ### */
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_7ndarray_4base_base(PyArrayObject *__pyx_v_self); /* proto*/
+static CYTHON_INLINE PyArray_Descr *__pyx_f_5numpy_7ndarray_5descr_descr(PyArrayObject *__pyx_v_self); /* proto*/
+static CYTHON_INLINE int __pyx_f_5numpy_7ndarray_4ndim_ndim(PyArrayObject *__pyx_v_self); /* proto*/
+static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_5shape_shape(PyArrayObject *__pyx_v_self); /* proto*/
+static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_7strides_strides(PyArrayObject *__pyx_v_self); /* proto*/
+static CYTHON_INLINE npy_intp __pyx_f_5numpy_7ndarray_4size_size(PyArrayObject *__pyx_v_self); /* proto*/
+static CYTHON_INLINE char *__pyx_f_5numpy_7ndarray_4data_data(PyArrayObject *__pyx_v_self); /* proto*/
+
+/* Module declarations from "libc.string" */
+
+/* Module declarations from "libc.stdio" */
+
+/* Module declarations from "__builtin__" */
+
+/* Module declarations from "cpython.type" */
+
+/* Module declarations from "cpython" */
+
+/* Module declarations from "cpython.object" */
+
+/* Module declarations from "cpython.ref" */
+
+/* Module declarations from "numpy" */
+
+/* Module declarations from "numpy" */
+
+/* Module declarations from "nms.cpu_nms" */
+static CYTHON_INLINE __pyx_t_5numpy_float32_t __pyx_f_3nms_7cpu_nms_max(__pyx_t_5numpy_float32_t, __pyx_t_5numpy_float32_t); /*proto*/
+static CYTHON_INLINE __pyx_t_5numpy_float32_t __pyx_f_3nms_7cpu_nms_min(__pyx_t_5numpy_float32_t, __pyx_t_5numpy_float32_t); /*proto*/
+/* #### Code section: typeinfo ### */
+static __Pyx_TypeInfo __Pyx_TypeInfo_nn___pyx_t_5numpy_float32_t = { "float32_t", NULL, sizeof(__pyx_t_5numpy_float32_t), { 0 }, 0, 'R', 0, 0 };
+static __Pyx_TypeInfo __Pyx_TypeInfo_nn___pyx_t_5numpy_int_t = { "int_t", NULL, sizeof(__pyx_t_5numpy_int_t), { 0 }, 0, __PYX_IS_UNSIGNED(__pyx_t_5numpy_int_t) ? 'U' : 'I', __PYX_IS_UNSIGNED(__pyx_t_5numpy_int_t), 0 };
+static __Pyx_TypeInfo __Pyx_TypeInfo_float = { "float", NULL, sizeof(float), { 0 }, 0, 'R', 0, 0 };
+/* #### Code section: before_global_var ### */
+#define __Pyx_MODULE_NAME "nms.cpu_nms"
+extern int __pyx_module_is_main_nms__cpu_nms;
+int __pyx_module_is_main_nms__cpu_nms = 0;
+
+/* Implementation of "nms.cpu_nms" */
+/* #### Code section: global_var ### */
+static PyObject *__pyx_builtin_range;
+static PyObject *__pyx_builtin_ImportError;
+/* #### Code section: string_decls ### */
+static const char __pyx_k_N[] = "N";
+static const char __pyx_k_h[] = "h";
+static const char __pyx_k_i[] = "_i";
+static const char __pyx_k_j[] = "_j";
+static const char __pyx_k_s[] = "s";
+static const char __pyx_k_w[] = "w";
+static const char __pyx_k_Nt[] = "Nt";
+static const char __pyx_k_ih[] = "ih";
+static const char __pyx_k_iw[] = "iw";
+static const char __pyx_k_np[] = "np";
+static const char __pyx_k_ov[] = "ov";
+static const char __pyx_k_ts[] = "ts";
+static const char __pyx_k_ua[] = "ua";
+static const char __pyx_k_x1[] = "x1";
+static const char __pyx_k_x2[] = "x2";
+static const char __pyx_k_y1[] = "y1";
+static const char __pyx_k_y2[] = "y2";
+static const char __pyx_k__10[] = "*";
+static const char __pyx_k__15[] = "?";
+static const char __pyx_k_exp[] = "exp";
+static const char __pyx_k_i_2[] = "i";
+static const char __pyx_k_int[] = "int";
+static const char __pyx_k_ix1[] = "ix1";
+static const char __pyx_k_ix2[] = "ix2";
+static const char __pyx_k_iy1[] = "iy1";
+static const char __pyx_k_iy2[] = "iy2";
+static const char __pyx_k_j_2[] = "j";
+static const char __pyx_k_ovr[] = "ovr";
+static const char __pyx_k_pos[] = "pos";
+static const char __pyx_k_tx1[] = "tx1";
+static const char __pyx_k_tx2[] = "tx2";
+static const char __pyx_k_ty1[] = "ty1";
+static const char __pyx_k_ty2[] = "ty2";
+static const char __pyx_k_xx1[] = "xx1";
+static const char __pyx_k_xx2[] = "xx2";
+static const char __pyx_k_yy1[] = "yy1";
+static const char __pyx_k_yy2[] = "yy2";
+static const char __pyx_k_area[] = "area";
+static const char __pyx_k_dets[] = "dets";
+static const char __pyx_k_keep[] = "keep";
+static const char __pyx_k_main[] = "__main__";
+static const char __pyx_k_name[] = "__name__";
+static const char __pyx_k_spec[] = "__spec__";
+static const char __pyx_k_test[] = "__test__";
+static const char __pyx_k_areas[] = "areas";
+static const char __pyx_k_boxes[] = "boxes";
+static const char __pyx_k_dtype[] = "dtype";
+static const char __pyx_k_iarea[] = "iarea";
+static const char __pyx_k_inter[] = "inter";
+static const char __pyx_k_ndets[] = "ndets";
+static const char __pyx_k_numpy[] = "numpy";
+static const char __pyx_k_order[] = "order";
+static const char __pyx_k_range[] = "range";
+static const char __pyx_k_sigma[] = "sigma";
+static const char __pyx_k_zeros[] = "zeros";
+static const char __pyx_k_import[] = "__import__";
+static const char __pyx_k_maxpos[] = "maxpos";
+static const char __pyx_k_method[] = "method";
+static const char __pyx_k_scores[] = "scores";
+static const char __pyx_k_thresh[] = "thresh";
+static const char __pyx_k_weight[] = "weight";
+static const char __pyx_k_argsort[] = "argsort";
+static const char __pyx_k_cpu_nms[] = "cpu_nms";
+static const char __pyx_k_box_area[] = "box_area";
+static const char __pyx_k_maxscore[] = "maxscore";
+static const char __pyx_k_threshold[] = "threshold";
+static const char __pyx_k_suppressed[] = "suppressed";
+static const char __pyx_k_ImportError[] = "ImportError";
+static const char __pyx_k_nms_cpu_nms[] = "nms.cpu_nms";
+static const char __pyx_k_cpu_soft_nms[] = "cpu_soft_nms";
+static const char __pyx_k_initializing[] = "_initializing";
+static const char __pyx_k_is_coroutine[] = "_is_coroutine";
+static const char __pyx_k_class_getitem[] = "__class_getitem__";
+static const char __pyx_k_nms_cpu_nms_pyx[] = "nms/cpu_nms.pyx";
+static const char __pyx_k_asyncio_coroutines[] = "asyncio.coroutines";
+static const char __pyx_k_cline_in_traceback[] = "cline_in_traceback";
+static const char __pyx_k_numpy_core_multiarray_failed_to[] = "numpy.core.multiarray failed to import";
+static const char __pyx_k_numpy_core_umath_failed_to_impor[] = "numpy.core.umath failed to import";
+/* #### Code section: decls ### */
+static PyObject *__pyx_pf_3nms_7cpu_nms_cpu_nms(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_dets, PyObject *__pyx_v_thresh); /* proto */
+static PyObject *__pyx_pf_3nms_7cpu_nms_2cpu_soft_nms(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_boxes, float __pyx_v_sigma, float __pyx_v_Nt, float __pyx_v_threshold, unsigned int __pyx_v_method); /* proto */
+/* #### Code section: late_includes ### */
+/* #### Code section: module_state ### */
+typedef struct {
+ PyObject *__pyx_d;
+ PyObject *__pyx_b;
+ PyObject *__pyx_cython_runtime;
+ PyObject *__pyx_empty_tuple;
+ PyObject *__pyx_empty_bytes;
+ PyObject *__pyx_empty_unicode;
+ #ifdef __Pyx_CyFunction_USED
+ PyTypeObject *__pyx_CyFunctionType;
+ #endif
+ #ifdef __Pyx_FusedFunction_USED
+ PyTypeObject *__pyx_FusedFunctionType;
+ #endif
+ #ifdef __Pyx_Generator_USED
+ PyTypeObject *__pyx_GeneratorType;
+ #endif
+ #ifdef __Pyx_IterableCoroutine_USED
+ PyTypeObject *__pyx_IterableCoroutineType;
+ #endif
+ #ifdef __Pyx_Coroutine_USED
+ PyTypeObject *__pyx_CoroutineAwaitType;
+ #endif
+ #ifdef __Pyx_Coroutine_USED
+ PyTypeObject *__pyx_CoroutineType;
+ #endif
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ PyTypeObject *__pyx_ptype_7cpython_4type_type;
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ PyTypeObject *__pyx_ptype_5numpy_dtype;
+ PyTypeObject *__pyx_ptype_5numpy_flatiter;
+ PyTypeObject *__pyx_ptype_5numpy_broadcast;
+ PyTypeObject *__pyx_ptype_5numpy_ndarray;
+ PyTypeObject *__pyx_ptype_5numpy_generic;
+ PyTypeObject *__pyx_ptype_5numpy_number;
+ PyTypeObject *__pyx_ptype_5numpy_integer;
+ PyTypeObject *__pyx_ptype_5numpy_signedinteger;
+ PyTypeObject *__pyx_ptype_5numpy_unsignedinteger;
+ PyTypeObject *__pyx_ptype_5numpy_inexact;
+ PyTypeObject *__pyx_ptype_5numpy_floating;
+ PyTypeObject *__pyx_ptype_5numpy_complexfloating;
+ PyTypeObject *__pyx_ptype_5numpy_flexible;
+ PyTypeObject *__pyx_ptype_5numpy_character;
+ PyTypeObject *__pyx_ptype_5numpy_ufunc;
+ #if CYTHON_USE_MODULE_STATE
+ #endif
+ PyObject *__pyx_n_s_ImportError;
+ PyObject *__pyx_n_s_N;
+ PyObject *__pyx_n_s_Nt;
+ PyObject *__pyx_n_s__10;
+ PyObject *__pyx_n_s__15;
+ PyObject *__pyx_n_s_area;
+ PyObject *__pyx_n_s_areas;
+ PyObject *__pyx_n_s_argsort;
+ PyObject *__pyx_n_s_asyncio_coroutines;
+ PyObject *__pyx_n_s_box_area;
+ PyObject *__pyx_n_s_boxes;
+ PyObject *__pyx_n_s_class_getitem;
+ PyObject *__pyx_n_s_cline_in_traceback;
+ PyObject *__pyx_n_s_cpu_nms;
+ PyObject *__pyx_n_s_cpu_soft_nms;
+ PyObject *__pyx_n_s_dets;
+ PyObject *__pyx_n_s_dtype;
+ PyObject *__pyx_n_s_exp;
+ PyObject *__pyx_n_s_h;
+ PyObject *__pyx_n_s_i;
+ PyObject *__pyx_n_s_i_2;
+ PyObject *__pyx_n_s_iarea;
+ PyObject *__pyx_n_s_ih;
+ PyObject *__pyx_n_s_import;
+ PyObject *__pyx_n_s_initializing;
+ PyObject *__pyx_n_s_int;
+ PyObject *__pyx_n_s_inter;
+ PyObject *__pyx_n_s_is_coroutine;
+ PyObject *__pyx_n_s_iw;
+ PyObject *__pyx_n_s_ix1;
+ PyObject *__pyx_n_s_ix2;
+ PyObject *__pyx_n_s_iy1;
+ PyObject *__pyx_n_s_iy2;
+ PyObject *__pyx_n_s_j;
+ PyObject *__pyx_n_s_j_2;
+ PyObject *__pyx_n_s_keep;
+ PyObject *__pyx_n_s_main;
+ PyObject *__pyx_n_s_maxpos;
+ PyObject *__pyx_n_s_maxscore;
+ PyObject *__pyx_n_s_method;
+ PyObject *__pyx_n_s_name;
+ PyObject *__pyx_n_s_ndets;
+ PyObject *__pyx_n_s_nms_cpu_nms;
+ PyObject *__pyx_kp_s_nms_cpu_nms_pyx;
+ PyObject *__pyx_n_s_np;
+ PyObject *__pyx_n_s_numpy;
+ PyObject *__pyx_kp_s_numpy_core_multiarray_failed_to;
+ PyObject *__pyx_kp_s_numpy_core_umath_failed_to_impor;
+ PyObject *__pyx_n_s_order;
+ PyObject *__pyx_n_s_ov;
+ PyObject *__pyx_n_s_ovr;
+ PyObject *__pyx_n_s_pos;
+ PyObject *__pyx_n_s_range;
+ PyObject *__pyx_n_s_s;
+ PyObject *__pyx_n_s_scores;
+ PyObject *__pyx_n_s_sigma;
+ PyObject *__pyx_n_s_spec;
+ PyObject *__pyx_n_s_suppressed;
+ PyObject *__pyx_n_s_test;
+ PyObject *__pyx_n_s_thresh;
+ PyObject *__pyx_n_s_threshold;
+ PyObject *__pyx_n_s_ts;
+ PyObject *__pyx_n_s_tx1;
+ PyObject *__pyx_n_s_tx2;
+ PyObject *__pyx_n_s_ty1;
+ PyObject *__pyx_n_s_ty2;
+ PyObject *__pyx_n_s_ua;
+ PyObject *__pyx_n_s_w;
+ PyObject *__pyx_n_s_weight;
+ PyObject *__pyx_n_s_x1;
+ PyObject *__pyx_n_s_x2;
+ PyObject *__pyx_n_s_xx1;
+ PyObject *__pyx_n_s_xx2;
+ PyObject *__pyx_n_s_y1;
+ PyObject *__pyx_n_s_y2;
+ PyObject *__pyx_n_s_yy1;
+ PyObject *__pyx_n_s_yy2;
+ PyObject *__pyx_n_s_zeros;
+ PyObject *__pyx_int_0;
+ PyObject *__pyx_int_1;
+ PyObject *__pyx_int_2;
+ PyObject *__pyx_int_3;
+ PyObject *__pyx_int_4;
+ PyObject *__pyx_int_neg_1;
+ PyObject *__pyx_tuple_;
+ PyObject *__pyx_slice__3;
+ PyObject *__pyx_slice__9;
+ PyObject *__pyx_tuple__2;
+ PyObject *__pyx_tuple__4;
+ PyObject *__pyx_tuple__5;
+ PyObject *__pyx_tuple__6;
+ PyObject *__pyx_tuple__7;
+ PyObject *__pyx_tuple__8;
+ PyObject *__pyx_tuple__11;
+ PyObject *__pyx_tuple__13;
+ PyObject *__pyx_codeobj__12;
+ PyObject *__pyx_codeobj__14;
+} __pyx_mstate;
+
+#if CYTHON_USE_MODULE_STATE
+#ifdef __cplusplus
+namespace {
+ extern struct PyModuleDef __pyx_moduledef;
+} /* anonymous namespace */
+#else
+static struct PyModuleDef __pyx_moduledef;
+#endif
+
+#define __pyx_mstate(o) ((__pyx_mstate *)__Pyx_PyModule_GetState(o))
+
+#define __pyx_mstate_global (__pyx_mstate(PyState_FindModule(&__pyx_moduledef)))
+
+#define __pyx_m (PyState_FindModule(&__pyx_moduledef))
+#else
+static __pyx_mstate __pyx_mstate_global_static =
+#ifdef __cplusplus
+ {};
+#else
+ {0};
+#endif
+static __pyx_mstate *__pyx_mstate_global = &__pyx_mstate_global_static;
+#endif
+/* #### Code section: module_state_clear ### */
+#if CYTHON_USE_MODULE_STATE
+static int __pyx_m_clear(PyObject *m) {
+ __pyx_mstate *clear_module_state = __pyx_mstate(m);
+ if (!clear_module_state) return 0;
+ Py_CLEAR(clear_module_state->__pyx_d);
+ Py_CLEAR(clear_module_state->__pyx_b);
+ Py_CLEAR(clear_module_state->__pyx_cython_runtime);
+ Py_CLEAR(clear_module_state->__pyx_empty_tuple);
+ Py_CLEAR(clear_module_state->__pyx_empty_bytes);
+ Py_CLEAR(clear_module_state->__pyx_empty_unicode);
+ #ifdef __Pyx_CyFunction_USED
+ Py_CLEAR(clear_module_state->__pyx_CyFunctionType);
+ #endif
+ #ifdef __Pyx_FusedFunction_USED
+ Py_CLEAR(clear_module_state->__pyx_FusedFunctionType);
+ #endif
+ Py_CLEAR(clear_module_state->__pyx_ptype_7cpython_4type_type);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_dtype);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_flatiter);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_broadcast);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_ndarray);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_generic);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_number);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_integer);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_signedinteger);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_unsignedinteger);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_inexact);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_floating);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_complexfloating);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_flexible);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_character);
+ Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_ufunc);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ImportError);
+ Py_CLEAR(clear_module_state->__pyx_n_s_N);
+ Py_CLEAR(clear_module_state->__pyx_n_s_Nt);
+ Py_CLEAR(clear_module_state->__pyx_n_s__10);
+ Py_CLEAR(clear_module_state->__pyx_n_s__15);
+ Py_CLEAR(clear_module_state->__pyx_n_s_area);
+ Py_CLEAR(clear_module_state->__pyx_n_s_areas);
+ Py_CLEAR(clear_module_state->__pyx_n_s_argsort);
+ Py_CLEAR(clear_module_state->__pyx_n_s_asyncio_coroutines);
+ Py_CLEAR(clear_module_state->__pyx_n_s_box_area);
+ Py_CLEAR(clear_module_state->__pyx_n_s_boxes);
+ Py_CLEAR(clear_module_state->__pyx_n_s_class_getitem);
+ Py_CLEAR(clear_module_state->__pyx_n_s_cline_in_traceback);
+ Py_CLEAR(clear_module_state->__pyx_n_s_cpu_nms);
+ Py_CLEAR(clear_module_state->__pyx_n_s_cpu_soft_nms);
+ Py_CLEAR(clear_module_state->__pyx_n_s_dets);
+ Py_CLEAR(clear_module_state->__pyx_n_s_dtype);
+ Py_CLEAR(clear_module_state->__pyx_n_s_exp);
+ Py_CLEAR(clear_module_state->__pyx_n_s_h);
+ Py_CLEAR(clear_module_state->__pyx_n_s_i);
+ Py_CLEAR(clear_module_state->__pyx_n_s_i_2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_iarea);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ih);
+ Py_CLEAR(clear_module_state->__pyx_n_s_import);
+ Py_CLEAR(clear_module_state->__pyx_n_s_initializing);
+ Py_CLEAR(clear_module_state->__pyx_n_s_int);
+ Py_CLEAR(clear_module_state->__pyx_n_s_inter);
+ Py_CLEAR(clear_module_state->__pyx_n_s_is_coroutine);
+ Py_CLEAR(clear_module_state->__pyx_n_s_iw);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ix1);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ix2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_iy1);
+ Py_CLEAR(clear_module_state->__pyx_n_s_iy2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_j);
+ Py_CLEAR(clear_module_state->__pyx_n_s_j_2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_keep);
+ Py_CLEAR(clear_module_state->__pyx_n_s_main);
+ Py_CLEAR(clear_module_state->__pyx_n_s_maxpos);
+ Py_CLEAR(clear_module_state->__pyx_n_s_maxscore);
+ Py_CLEAR(clear_module_state->__pyx_n_s_method);
+ Py_CLEAR(clear_module_state->__pyx_n_s_name);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ndets);
+ Py_CLEAR(clear_module_state->__pyx_n_s_nms_cpu_nms);
+ Py_CLEAR(clear_module_state->__pyx_kp_s_nms_cpu_nms_pyx);
+ Py_CLEAR(clear_module_state->__pyx_n_s_np);
+ Py_CLEAR(clear_module_state->__pyx_n_s_numpy);
+ Py_CLEAR(clear_module_state->__pyx_kp_s_numpy_core_multiarray_failed_to);
+ Py_CLEAR(clear_module_state->__pyx_kp_s_numpy_core_umath_failed_to_impor);
+ Py_CLEAR(clear_module_state->__pyx_n_s_order);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ov);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ovr);
+ Py_CLEAR(clear_module_state->__pyx_n_s_pos);
+ Py_CLEAR(clear_module_state->__pyx_n_s_range);
+ Py_CLEAR(clear_module_state->__pyx_n_s_s);
+ Py_CLEAR(clear_module_state->__pyx_n_s_scores);
+ Py_CLEAR(clear_module_state->__pyx_n_s_sigma);
+ Py_CLEAR(clear_module_state->__pyx_n_s_spec);
+ Py_CLEAR(clear_module_state->__pyx_n_s_suppressed);
+ Py_CLEAR(clear_module_state->__pyx_n_s_test);
+ Py_CLEAR(clear_module_state->__pyx_n_s_thresh);
+ Py_CLEAR(clear_module_state->__pyx_n_s_threshold);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ts);
+ Py_CLEAR(clear_module_state->__pyx_n_s_tx1);
+ Py_CLEAR(clear_module_state->__pyx_n_s_tx2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ty1);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ty2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_ua);
+ Py_CLEAR(clear_module_state->__pyx_n_s_w);
+ Py_CLEAR(clear_module_state->__pyx_n_s_weight);
+ Py_CLEAR(clear_module_state->__pyx_n_s_x1);
+ Py_CLEAR(clear_module_state->__pyx_n_s_x2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_xx1);
+ Py_CLEAR(clear_module_state->__pyx_n_s_xx2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_y1);
+ Py_CLEAR(clear_module_state->__pyx_n_s_y2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_yy1);
+ Py_CLEAR(clear_module_state->__pyx_n_s_yy2);
+ Py_CLEAR(clear_module_state->__pyx_n_s_zeros);
+ Py_CLEAR(clear_module_state->__pyx_int_0);
+ Py_CLEAR(clear_module_state->__pyx_int_1);
+ Py_CLEAR(clear_module_state->__pyx_int_2);
+ Py_CLEAR(clear_module_state->__pyx_int_3);
+ Py_CLEAR(clear_module_state->__pyx_int_4);
+ Py_CLEAR(clear_module_state->__pyx_int_neg_1);
+ Py_CLEAR(clear_module_state->__pyx_tuple_);
+ Py_CLEAR(clear_module_state->__pyx_slice__3);
+ Py_CLEAR(clear_module_state->__pyx_slice__9);
+ Py_CLEAR(clear_module_state->__pyx_tuple__2);
+ Py_CLEAR(clear_module_state->__pyx_tuple__4);
+ Py_CLEAR(clear_module_state->__pyx_tuple__5);
+ Py_CLEAR(clear_module_state->__pyx_tuple__6);
+ Py_CLEAR(clear_module_state->__pyx_tuple__7);
+ Py_CLEAR(clear_module_state->__pyx_tuple__8);
+ Py_CLEAR(clear_module_state->__pyx_tuple__11);
+ Py_CLEAR(clear_module_state->__pyx_tuple__13);
+ Py_CLEAR(clear_module_state->__pyx_codeobj__12);
+ Py_CLEAR(clear_module_state->__pyx_codeobj__14);
+ return 0;
+}
+#endif
+/* #### Code section: module_state_traverse ### */
+#if CYTHON_USE_MODULE_STATE
+static int __pyx_m_traverse(PyObject *m, visitproc visit, void *arg) {
+ __pyx_mstate *traverse_module_state = __pyx_mstate(m);
+ if (!traverse_module_state) return 0;
+ Py_VISIT(traverse_module_state->__pyx_d);
+ Py_VISIT(traverse_module_state->__pyx_b);
+ Py_VISIT(traverse_module_state->__pyx_cython_runtime);
+ Py_VISIT(traverse_module_state->__pyx_empty_tuple);
+ Py_VISIT(traverse_module_state->__pyx_empty_bytes);
+ Py_VISIT(traverse_module_state->__pyx_empty_unicode);
+ #ifdef __Pyx_CyFunction_USED
+ Py_VISIT(traverse_module_state->__pyx_CyFunctionType);
+ #endif
+ #ifdef __Pyx_FusedFunction_USED
+ Py_VISIT(traverse_module_state->__pyx_FusedFunctionType);
+ #endif
+ Py_VISIT(traverse_module_state->__pyx_ptype_7cpython_4type_type);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_dtype);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_flatiter);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_broadcast);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_ndarray);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_generic);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_number);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_integer);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_signedinteger);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_unsignedinteger);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_inexact);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_floating);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_complexfloating);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_flexible);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_character);
+ Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_ufunc);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ImportError);
+ Py_VISIT(traverse_module_state->__pyx_n_s_N);
+ Py_VISIT(traverse_module_state->__pyx_n_s_Nt);
+ Py_VISIT(traverse_module_state->__pyx_n_s__10);
+ Py_VISIT(traverse_module_state->__pyx_n_s__15);
+ Py_VISIT(traverse_module_state->__pyx_n_s_area);
+ Py_VISIT(traverse_module_state->__pyx_n_s_areas);
+ Py_VISIT(traverse_module_state->__pyx_n_s_argsort);
+ Py_VISIT(traverse_module_state->__pyx_n_s_asyncio_coroutines);
+ Py_VISIT(traverse_module_state->__pyx_n_s_box_area);
+ Py_VISIT(traverse_module_state->__pyx_n_s_boxes);
+ Py_VISIT(traverse_module_state->__pyx_n_s_class_getitem);
+ Py_VISIT(traverse_module_state->__pyx_n_s_cline_in_traceback);
+ Py_VISIT(traverse_module_state->__pyx_n_s_cpu_nms);
+ Py_VISIT(traverse_module_state->__pyx_n_s_cpu_soft_nms);
+ Py_VISIT(traverse_module_state->__pyx_n_s_dets);
+ Py_VISIT(traverse_module_state->__pyx_n_s_dtype);
+ Py_VISIT(traverse_module_state->__pyx_n_s_exp);
+ Py_VISIT(traverse_module_state->__pyx_n_s_h);
+ Py_VISIT(traverse_module_state->__pyx_n_s_i);
+ Py_VISIT(traverse_module_state->__pyx_n_s_i_2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_iarea);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ih);
+ Py_VISIT(traverse_module_state->__pyx_n_s_import);
+ Py_VISIT(traverse_module_state->__pyx_n_s_initializing);
+ Py_VISIT(traverse_module_state->__pyx_n_s_int);
+ Py_VISIT(traverse_module_state->__pyx_n_s_inter);
+ Py_VISIT(traverse_module_state->__pyx_n_s_is_coroutine);
+ Py_VISIT(traverse_module_state->__pyx_n_s_iw);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ix1);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ix2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_iy1);
+ Py_VISIT(traverse_module_state->__pyx_n_s_iy2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_j);
+ Py_VISIT(traverse_module_state->__pyx_n_s_j_2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_keep);
+ Py_VISIT(traverse_module_state->__pyx_n_s_main);
+ Py_VISIT(traverse_module_state->__pyx_n_s_maxpos);
+ Py_VISIT(traverse_module_state->__pyx_n_s_maxscore);
+ Py_VISIT(traverse_module_state->__pyx_n_s_method);
+ Py_VISIT(traverse_module_state->__pyx_n_s_name);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ndets);
+ Py_VISIT(traverse_module_state->__pyx_n_s_nms_cpu_nms);
+ Py_VISIT(traverse_module_state->__pyx_kp_s_nms_cpu_nms_pyx);
+ Py_VISIT(traverse_module_state->__pyx_n_s_np);
+ Py_VISIT(traverse_module_state->__pyx_n_s_numpy);
+ Py_VISIT(traverse_module_state->__pyx_kp_s_numpy_core_multiarray_failed_to);
+ Py_VISIT(traverse_module_state->__pyx_kp_s_numpy_core_umath_failed_to_impor);
+ Py_VISIT(traverse_module_state->__pyx_n_s_order);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ov);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ovr);
+ Py_VISIT(traverse_module_state->__pyx_n_s_pos);
+ Py_VISIT(traverse_module_state->__pyx_n_s_range);
+ Py_VISIT(traverse_module_state->__pyx_n_s_s);
+ Py_VISIT(traverse_module_state->__pyx_n_s_scores);
+ Py_VISIT(traverse_module_state->__pyx_n_s_sigma);
+ Py_VISIT(traverse_module_state->__pyx_n_s_spec);
+ Py_VISIT(traverse_module_state->__pyx_n_s_suppressed);
+ Py_VISIT(traverse_module_state->__pyx_n_s_test);
+ Py_VISIT(traverse_module_state->__pyx_n_s_thresh);
+ Py_VISIT(traverse_module_state->__pyx_n_s_threshold);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ts);
+ Py_VISIT(traverse_module_state->__pyx_n_s_tx1);
+ Py_VISIT(traverse_module_state->__pyx_n_s_tx2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ty1);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ty2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_ua);
+ Py_VISIT(traverse_module_state->__pyx_n_s_w);
+ Py_VISIT(traverse_module_state->__pyx_n_s_weight);
+ Py_VISIT(traverse_module_state->__pyx_n_s_x1);
+ Py_VISIT(traverse_module_state->__pyx_n_s_x2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_xx1);
+ Py_VISIT(traverse_module_state->__pyx_n_s_xx2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_y1);
+ Py_VISIT(traverse_module_state->__pyx_n_s_y2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_yy1);
+ Py_VISIT(traverse_module_state->__pyx_n_s_yy2);
+ Py_VISIT(traverse_module_state->__pyx_n_s_zeros);
+ Py_VISIT(traverse_module_state->__pyx_int_0);
+ Py_VISIT(traverse_module_state->__pyx_int_1);
+ Py_VISIT(traverse_module_state->__pyx_int_2);
+ Py_VISIT(traverse_module_state->__pyx_int_3);
+ Py_VISIT(traverse_module_state->__pyx_int_4);
+ Py_VISIT(traverse_module_state->__pyx_int_neg_1);
+ Py_VISIT(traverse_module_state->__pyx_tuple_);
+ Py_VISIT(traverse_module_state->__pyx_slice__3);
+ Py_VISIT(traverse_module_state->__pyx_slice__9);
+ Py_VISIT(traverse_module_state->__pyx_tuple__2);
+ Py_VISIT(traverse_module_state->__pyx_tuple__4);
+ Py_VISIT(traverse_module_state->__pyx_tuple__5);
+ Py_VISIT(traverse_module_state->__pyx_tuple__6);
+ Py_VISIT(traverse_module_state->__pyx_tuple__7);
+ Py_VISIT(traverse_module_state->__pyx_tuple__8);
+ Py_VISIT(traverse_module_state->__pyx_tuple__11);
+ Py_VISIT(traverse_module_state->__pyx_tuple__13);
+ Py_VISIT(traverse_module_state->__pyx_codeobj__12);
+ Py_VISIT(traverse_module_state->__pyx_codeobj__14);
+ return 0;
+}
+#endif
+/* #### Code section: module_state_defines ### */
+#define __pyx_d __pyx_mstate_global->__pyx_d
+#define __pyx_b __pyx_mstate_global->__pyx_b
+#define __pyx_cython_runtime __pyx_mstate_global->__pyx_cython_runtime
+#define __pyx_empty_tuple __pyx_mstate_global->__pyx_empty_tuple
+#define __pyx_empty_bytes __pyx_mstate_global->__pyx_empty_bytes
+#define __pyx_empty_unicode __pyx_mstate_global->__pyx_empty_unicode
+#ifdef __Pyx_CyFunction_USED
+#define __pyx_CyFunctionType __pyx_mstate_global->__pyx_CyFunctionType
+#endif
+#ifdef __Pyx_FusedFunction_USED
+#define __pyx_FusedFunctionType __pyx_mstate_global->__pyx_FusedFunctionType
+#endif
+#ifdef __Pyx_Generator_USED
+#define __pyx_GeneratorType __pyx_mstate_global->__pyx_GeneratorType
+#endif
+#ifdef __Pyx_IterableCoroutine_USED
+#define __pyx_IterableCoroutineType __pyx_mstate_global->__pyx_IterableCoroutineType
+#endif
+#ifdef __Pyx_Coroutine_USED
+#define __pyx_CoroutineAwaitType __pyx_mstate_global->__pyx_CoroutineAwaitType
+#endif
+#ifdef __Pyx_Coroutine_USED
+#define __pyx_CoroutineType __pyx_mstate_global->__pyx_CoroutineType
+#endif
+#if CYTHON_USE_MODULE_STATE
+#endif
+#if CYTHON_USE_MODULE_STATE
+#endif
+#if CYTHON_USE_MODULE_STATE
+#endif
+#if CYTHON_USE_MODULE_STATE
+#endif
+#define __pyx_ptype_7cpython_4type_type __pyx_mstate_global->__pyx_ptype_7cpython_4type_type
+#if CYTHON_USE_MODULE_STATE
+#endif
+#if CYTHON_USE_MODULE_STATE
+#endif
+#if CYTHON_USE_MODULE_STATE
+#endif
+#if CYTHON_USE_MODULE_STATE
+#endif
+#if CYTHON_USE_MODULE_STATE
+#endif
+#define __pyx_ptype_5numpy_dtype __pyx_mstate_global->__pyx_ptype_5numpy_dtype
+#define __pyx_ptype_5numpy_flatiter __pyx_mstate_global->__pyx_ptype_5numpy_flatiter
+#define __pyx_ptype_5numpy_broadcast __pyx_mstate_global->__pyx_ptype_5numpy_broadcast
+#define __pyx_ptype_5numpy_ndarray __pyx_mstate_global->__pyx_ptype_5numpy_ndarray
+#define __pyx_ptype_5numpy_generic __pyx_mstate_global->__pyx_ptype_5numpy_generic
+#define __pyx_ptype_5numpy_number __pyx_mstate_global->__pyx_ptype_5numpy_number
+#define __pyx_ptype_5numpy_integer __pyx_mstate_global->__pyx_ptype_5numpy_integer
+#define __pyx_ptype_5numpy_signedinteger __pyx_mstate_global->__pyx_ptype_5numpy_signedinteger
+#define __pyx_ptype_5numpy_unsignedinteger __pyx_mstate_global->__pyx_ptype_5numpy_unsignedinteger
+#define __pyx_ptype_5numpy_inexact __pyx_mstate_global->__pyx_ptype_5numpy_inexact
+#define __pyx_ptype_5numpy_floating __pyx_mstate_global->__pyx_ptype_5numpy_floating
+#define __pyx_ptype_5numpy_complexfloating __pyx_mstate_global->__pyx_ptype_5numpy_complexfloating
+#define __pyx_ptype_5numpy_flexible __pyx_mstate_global->__pyx_ptype_5numpy_flexible
+#define __pyx_ptype_5numpy_character __pyx_mstate_global->__pyx_ptype_5numpy_character
+#define __pyx_ptype_5numpy_ufunc __pyx_mstate_global->__pyx_ptype_5numpy_ufunc
+#if CYTHON_USE_MODULE_STATE
+#endif
+#define __pyx_n_s_ImportError __pyx_mstate_global->__pyx_n_s_ImportError
+#define __pyx_n_s_N __pyx_mstate_global->__pyx_n_s_N
+#define __pyx_n_s_Nt __pyx_mstate_global->__pyx_n_s_Nt
+#define __pyx_n_s__10 __pyx_mstate_global->__pyx_n_s__10
+#define __pyx_n_s__15 __pyx_mstate_global->__pyx_n_s__15
+#define __pyx_n_s_area __pyx_mstate_global->__pyx_n_s_area
+#define __pyx_n_s_areas __pyx_mstate_global->__pyx_n_s_areas
+#define __pyx_n_s_argsort __pyx_mstate_global->__pyx_n_s_argsort
+#define __pyx_n_s_asyncio_coroutines __pyx_mstate_global->__pyx_n_s_asyncio_coroutines
+#define __pyx_n_s_box_area __pyx_mstate_global->__pyx_n_s_box_area
+#define __pyx_n_s_boxes __pyx_mstate_global->__pyx_n_s_boxes
+#define __pyx_n_s_class_getitem __pyx_mstate_global->__pyx_n_s_class_getitem
+#define __pyx_n_s_cline_in_traceback __pyx_mstate_global->__pyx_n_s_cline_in_traceback
+#define __pyx_n_s_cpu_nms __pyx_mstate_global->__pyx_n_s_cpu_nms
+#define __pyx_n_s_cpu_soft_nms __pyx_mstate_global->__pyx_n_s_cpu_soft_nms
+#define __pyx_n_s_dets __pyx_mstate_global->__pyx_n_s_dets
+#define __pyx_n_s_dtype __pyx_mstate_global->__pyx_n_s_dtype
+#define __pyx_n_s_exp __pyx_mstate_global->__pyx_n_s_exp
+#define __pyx_n_s_h __pyx_mstate_global->__pyx_n_s_h
+#define __pyx_n_s_i __pyx_mstate_global->__pyx_n_s_i
+#define __pyx_n_s_i_2 __pyx_mstate_global->__pyx_n_s_i_2
+#define __pyx_n_s_iarea __pyx_mstate_global->__pyx_n_s_iarea
+#define __pyx_n_s_ih __pyx_mstate_global->__pyx_n_s_ih
+#define __pyx_n_s_import __pyx_mstate_global->__pyx_n_s_import
+#define __pyx_n_s_initializing __pyx_mstate_global->__pyx_n_s_initializing
+#define __pyx_n_s_int __pyx_mstate_global->__pyx_n_s_int
+#define __pyx_n_s_inter __pyx_mstate_global->__pyx_n_s_inter
+#define __pyx_n_s_is_coroutine __pyx_mstate_global->__pyx_n_s_is_coroutine
+#define __pyx_n_s_iw __pyx_mstate_global->__pyx_n_s_iw
+#define __pyx_n_s_ix1 __pyx_mstate_global->__pyx_n_s_ix1
+#define __pyx_n_s_ix2 __pyx_mstate_global->__pyx_n_s_ix2
+#define __pyx_n_s_iy1 __pyx_mstate_global->__pyx_n_s_iy1
+#define __pyx_n_s_iy2 __pyx_mstate_global->__pyx_n_s_iy2
+#define __pyx_n_s_j __pyx_mstate_global->__pyx_n_s_j
+#define __pyx_n_s_j_2 __pyx_mstate_global->__pyx_n_s_j_2
+#define __pyx_n_s_keep __pyx_mstate_global->__pyx_n_s_keep
+#define __pyx_n_s_main __pyx_mstate_global->__pyx_n_s_main
+#define __pyx_n_s_maxpos __pyx_mstate_global->__pyx_n_s_maxpos
+#define __pyx_n_s_maxscore __pyx_mstate_global->__pyx_n_s_maxscore
+#define __pyx_n_s_method __pyx_mstate_global->__pyx_n_s_method
+#define __pyx_n_s_name __pyx_mstate_global->__pyx_n_s_name
+#define __pyx_n_s_ndets __pyx_mstate_global->__pyx_n_s_ndets
+#define __pyx_n_s_nms_cpu_nms __pyx_mstate_global->__pyx_n_s_nms_cpu_nms
+#define __pyx_kp_s_nms_cpu_nms_pyx __pyx_mstate_global->__pyx_kp_s_nms_cpu_nms_pyx
+#define __pyx_n_s_np __pyx_mstate_global->__pyx_n_s_np
+#define __pyx_n_s_numpy __pyx_mstate_global->__pyx_n_s_numpy
+#define __pyx_kp_s_numpy_core_multiarray_failed_to __pyx_mstate_global->__pyx_kp_s_numpy_core_multiarray_failed_to
+#define __pyx_kp_s_numpy_core_umath_failed_to_impor __pyx_mstate_global->__pyx_kp_s_numpy_core_umath_failed_to_impor
+#define __pyx_n_s_order __pyx_mstate_global->__pyx_n_s_order
+#define __pyx_n_s_ov __pyx_mstate_global->__pyx_n_s_ov
+#define __pyx_n_s_ovr __pyx_mstate_global->__pyx_n_s_ovr
+#define __pyx_n_s_pos __pyx_mstate_global->__pyx_n_s_pos
+#define __pyx_n_s_range __pyx_mstate_global->__pyx_n_s_range
+#define __pyx_n_s_s __pyx_mstate_global->__pyx_n_s_s
+#define __pyx_n_s_scores __pyx_mstate_global->__pyx_n_s_scores
+#define __pyx_n_s_sigma __pyx_mstate_global->__pyx_n_s_sigma
+#define __pyx_n_s_spec __pyx_mstate_global->__pyx_n_s_spec
+#define __pyx_n_s_suppressed __pyx_mstate_global->__pyx_n_s_suppressed
+#define __pyx_n_s_test __pyx_mstate_global->__pyx_n_s_test
+#define __pyx_n_s_thresh __pyx_mstate_global->__pyx_n_s_thresh
+#define __pyx_n_s_threshold __pyx_mstate_global->__pyx_n_s_threshold
+#define __pyx_n_s_ts __pyx_mstate_global->__pyx_n_s_ts
+#define __pyx_n_s_tx1 __pyx_mstate_global->__pyx_n_s_tx1
+#define __pyx_n_s_tx2 __pyx_mstate_global->__pyx_n_s_tx2
+#define __pyx_n_s_ty1 __pyx_mstate_global->__pyx_n_s_ty1
+#define __pyx_n_s_ty2 __pyx_mstate_global->__pyx_n_s_ty2
+#define __pyx_n_s_ua __pyx_mstate_global->__pyx_n_s_ua
+#define __pyx_n_s_w __pyx_mstate_global->__pyx_n_s_w
+#define __pyx_n_s_weight __pyx_mstate_global->__pyx_n_s_weight
+#define __pyx_n_s_x1 __pyx_mstate_global->__pyx_n_s_x1
+#define __pyx_n_s_x2 __pyx_mstate_global->__pyx_n_s_x2
+#define __pyx_n_s_xx1 __pyx_mstate_global->__pyx_n_s_xx1
+#define __pyx_n_s_xx2 __pyx_mstate_global->__pyx_n_s_xx2
+#define __pyx_n_s_y1 __pyx_mstate_global->__pyx_n_s_y1
+#define __pyx_n_s_y2 __pyx_mstate_global->__pyx_n_s_y2
+#define __pyx_n_s_yy1 __pyx_mstate_global->__pyx_n_s_yy1
+#define __pyx_n_s_yy2 __pyx_mstate_global->__pyx_n_s_yy2
+#define __pyx_n_s_zeros __pyx_mstate_global->__pyx_n_s_zeros
+#define __pyx_int_0 __pyx_mstate_global->__pyx_int_0
+#define __pyx_int_1 __pyx_mstate_global->__pyx_int_1
+#define __pyx_int_2 __pyx_mstate_global->__pyx_int_2
+#define __pyx_int_3 __pyx_mstate_global->__pyx_int_3
+#define __pyx_int_4 __pyx_mstate_global->__pyx_int_4
+#define __pyx_int_neg_1 __pyx_mstate_global->__pyx_int_neg_1
+#define __pyx_tuple_ __pyx_mstate_global->__pyx_tuple_
+#define __pyx_slice__3 __pyx_mstate_global->__pyx_slice__3
+#define __pyx_slice__9 __pyx_mstate_global->__pyx_slice__9
+#define __pyx_tuple__2 __pyx_mstate_global->__pyx_tuple__2
+#define __pyx_tuple__4 __pyx_mstate_global->__pyx_tuple__4
+#define __pyx_tuple__5 __pyx_mstate_global->__pyx_tuple__5
+#define __pyx_tuple__6 __pyx_mstate_global->__pyx_tuple__6
+#define __pyx_tuple__7 __pyx_mstate_global->__pyx_tuple__7
+#define __pyx_tuple__8 __pyx_mstate_global->__pyx_tuple__8
+#define __pyx_tuple__11 __pyx_mstate_global->__pyx_tuple__11
+#define __pyx_tuple__13 __pyx_mstate_global->__pyx_tuple__13
+#define __pyx_codeobj__12 __pyx_mstate_global->__pyx_codeobj__12
+#define __pyx_codeobj__14 __pyx_mstate_global->__pyx_codeobj__14
+/* #### Code section: module_code ### */
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":245
+ *
+ * @property
+ * cdef inline PyObject* base(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns a borrowed reference to the object owning the data/memory.
+ * """
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_7ndarray_4base_base(PyArrayObject *__pyx_v_self) {
+ PyObject *__pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":248
+ * """Returns a borrowed reference to the object owning the data/memory.
+ * """
+ * return PyArray_BASE(self) # <<<<<<<<<<<<<<
+ *
+ * @property
+ */
+ __pyx_r = PyArray_BASE(__pyx_v_self);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":245
+ *
+ * @property
+ * cdef inline PyObject* base(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns a borrowed reference to the object owning the data/memory.
+ * """
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":251
+ *
+ * @property
+ * cdef inline dtype descr(self): # <<<<<<<<<<<<<<
+ * """Returns an owned reference to the dtype of the array.
+ * """
+ */
+
+static CYTHON_INLINE PyArray_Descr *__pyx_f_5numpy_7ndarray_5descr_descr(PyArrayObject *__pyx_v_self) {
+ PyArray_Descr *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyArray_Descr *__pyx_t_1;
+ __Pyx_RefNannySetupContext("descr", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":254
+ * """Returns an owned reference to the dtype of the array.
+ * """
+ * return PyArray_DESCR(self) # <<<<<<<<<<<<<<
+ *
+ * @property
+ */
+ __Pyx_XDECREF((PyObject *)__pyx_r);
+ __pyx_t_1 = PyArray_DESCR(__pyx_v_self);
+ __Pyx_INCREF((PyObject *)((PyArray_Descr *)__pyx_t_1));
+ __pyx_r = ((PyArray_Descr *)__pyx_t_1);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":251
+ *
+ * @property
+ * cdef inline dtype descr(self): # <<<<<<<<<<<<<<
+ * """Returns an owned reference to the dtype of the array.
+ * """
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ __Pyx_XGIVEREF((PyObject *)__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":257
+ *
+ * @property
+ * cdef inline int ndim(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns the number of dimensions in the array.
+ * """
+ */
+
+static CYTHON_INLINE int __pyx_f_5numpy_7ndarray_4ndim_ndim(PyArrayObject *__pyx_v_self) {
+ int __pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":260
+ * """Returns the number of dimensions in the array.
+ * """
+ * return PyArray_NDIM(self) # <<<<<<<<<<<<<<
+ *
+ * @property
+ */
+ __pyx_r = PyArray_NDIM(__pyx_v_self);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":257
+ *
+ * @property
+ * cdef inline int ndim(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns the number of dimensions in the array.
+ * """
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":263
+ *
+ * @property
+ * cdef inline npy_intp *shape(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns a pointer to the dimensions/shape of the array.
+ * The number of elements matches the number of dimensions of the array (ndim).
+ */
+
+static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_5shape_shape(PyArrayObject *__pyx_v_self) {
+ npy_intp *__pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":268
+ * Can return NULL for 0-dimensional arrays.
+ * """
+ * return PyArray_DIMS(self) # <<<<<<<<<<<<<<
+ *
+ * @property
+ */
+ __pyx_r = PyArray_DIMS(__pyx_v_self);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":263
+ *
+ * @property
+ * cdef inline npy_intp *shape(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns a pointer to the dimensions/shape of the array.
+ * The number of elements matches the number of dimensions of the array (ndim).
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":271
+ *
+ * @property
+ * cdef inline npy_intp *strides(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns a pointer to the strides of the array.
+ * The number of elements matches the number of dimensions of the array (ndim).
+ */
+
+static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_7strides_strides(PyArrayObject *__pyx_v_self) {
+ npy_intp *__pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":275
+ * The number of elements matches the number of dimensions of the array (ndim).
+ * """
+ * return PyArray_STRIDES(self) # <<<<<<<<<<<<<<
+ *
+ * @property
+ */
+ __pyx_r = PyArray_STRIDES(__pyx_v_self);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":271
+ *
+ * @property
+ * cdef inline npy_intp *strides(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns a pointer to the strides of the array.
+ * The number of elements matches the number of dimensions of the array (ndim).
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":278
+ *
+ * @property
+ * cdef inline npy_intp size(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns the total size (in number of elements) of the array.
+ * """
+ */
+
+static CYTHON_INLINE npy_intp __pyx_f_5numpy_7ndarray_4size_size(PyArrayObject *__pyx_v_self) {
+ npy_intp __pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":281
+ * """Returns the total size (in number of elements) of the array.
+ * """
+ * return PyArray_SIZE(self) # <<<<<<<<<<<<<<
+ *
+ * @property
+ */
+ __pyx_r = PyArray_SIZE(__pyx_v_self);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":278
+ *
+ * @property
+ * cdef inline npy_intp size(self) nogil: # <<<<<<<<<<<<<<
+ * """Returns the total size (in number of elements) of the array.
+ * """
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":284
+ *
+ * @property
+ * cdef inline char* data(self) nogil: # <<<<<<<<<<<<<<
+ * """The pointer to the data buffer as a char*.
+ * This is provided for legacy reasons to avoid direct struct field access.
+ */
+
+static CYTHON_INLINE char *__pyx_f_5numpy_7ndarray_4data_data(PyArrayObject *__pyx_v_self) {
+ char *__pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":290
+ * of `PyArray_DATA()` instead, which returns a 'void*'.
+ * """
+ * return PyArray_BYTES(self) # <<<<<<<<<<<<<<
+ *
+ * ctypedef unsigned char npy_bool
+ */
+ __pyx_r = PyArray_BYTES(__pyx_v_self);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":284
+ *
+ * @property
+ * cdef inline char* data(self) nogil: # <<<<<<<<<<<<<<
+ * """The pointer to the data buffer as a char*.
+ * This is provided for legacy reasons to avoid direct struct field access.
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":776
+ * ctypedef npy_cdouble complex_t
+ *
+ * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(1, a)
+ *
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew1(PyObject *__pyx_v_a) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew1", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":777
+ *
+ * cdef inline object PyArray_MultiIterNew1(a):
+ * return PyArray_MultiIterNew(1, a) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew2(a, b):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(1, ((void *)__pyx_v_a)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 777, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":776
+ * ctypedef npy_cdouble complex_t
+ *
+ * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(1, a)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew1", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":779
+ * return PyArray_MultiIterNew(1, a)
+ *
+ * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(2, a, b)
+ *
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew2(PyObject *__pyx_v_a, PyObject *__pyx_v_b) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew2", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":780
+ *
+ * cdef inline object PyArray_MultiIterNew2(a, b):
+ * return PyArray_MultiIterNew(2, a, b) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew3(a, b, c):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(2, ((void *)__pyx_v_a), ((void *)__pyx_v_b)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 780, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":779
+ * return PyArray_MultiIterNew(1, a)
+ *
+ * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(2, a, b)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew2", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":782
+ * return PyArray_MultiIterNew(2, a, b)
+ *
+ * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(3, a, b, c)
+ *
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew3(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew3", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":783
+ *
+ * cdef inline object PyArray_MultiIterNew3(a, b, c):
+ * return PyArray_MultiIterNew(3, a, b, c) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew4(a, b, c, d):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(3, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 783, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":782
+ * return PyArray_MultiIterNew(2, a, b)
+ *
+ * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(3, a, b, c)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew3", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":785
+ * return PyArray_MultiIterNew(3, a, b, c)
+ *
+ * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(4, a, b, c, d)
+ *
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew4(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew4", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":786
+ *
+ * cdef inline object PyArray_MultiIterNew4(a, b, c, d):
+ * return PyArray_MultiIterNew(4, a, b, c, d) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(4, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 786, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":785
+ * return PyArray_MultiIterNew(3, a, b, c)
+ *
+ * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(4, a, b, c, d)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew4", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":788
+ * return PyArray_MultiIterNew(4, a, b, c, d)
+ *
+ * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(5, a, b, c, d, e)
+ *
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew5(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d, PyObject *__pyx_v_e) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("PyArray_MultiIterNew5", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":789
+ *
+ * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e):
+ * return PyArray_MultiIterNew(5, a, b, c, d, e) # <<<<<<<<<<<<<<
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_t_1 = PyArray_MultiIterNew(5, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d), ((void *)__pyx_v_e)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 789, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_r = __pyx_t_1;
+ __pyx_t_1 = 0;
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":788
+ * return PyArray_MultiIterNew(4, a, b, c, d)
+ *
+ * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<<
+ * return PyArray_MultiIterNew(5, a, b, c, d, e)
+ *
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_AddTraceback("numpy.PyArray_MultiIterNew5", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = 0;
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":791
+ * return PyArray_MultiIterNew(5, a, b, c, d, e)
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<<
+ * if PyDataType_HASSUBARRAY(d):
+ * return d.subarray.shape
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyDataType_SHAPE(PyArray_Descr *__pyx_v_d) {
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ int __pyx_t_1;
+ __Pyx_RefNannySetupContext("PyDataType_SHAPE", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":792
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d):
+ * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<<
+ * return d.subarray.shape
+ * else:
+ */
+ __pyx_t_1 = PyDataType_HASSUBARRAY(__pyx_v_d);
+ if (__pyx_t_1) {
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":793
+ * cdef inline tuple PyDataType_SHAPE(dtype d):
+ * if PyDataType_HASSUBARRAY(d):
+ * return d.subarray.shape # <<<<<<<<<<<<<<
+ * else:
+ * return ()
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __Pyx_INCREF(((PyObject*)__pyx_v_d->subarray->shape));
+ __pyx_r = ((PyObject*)__pyx_v_d->subarray->shape);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":792
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d):
+ * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<<
+ * return d.subarray.shape
+ * else:
+ */
+ }
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":795
+ * return d.subarray.shape
+ * else:
+ * return () # <<<<<<<<<<<<<<
+ *
+ *
+ */
+ /*else*/ {
+ __Pyx_XDECREF(__pyx_r);
+ __Pyx_INCREF(__pyx_empty_tuple);
+ __pyx_r = __pyx_empty_tuple;
+ goto __pyx_L0;
+ }
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":791
+ * return PyArray_MultiIterNew(5, a, b, c, d, e)
+ *
+ * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<<
+ * if PyDataType_HASSUBARRAY(d):
+ * return d.subarray.shape
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":970
+ * int _import_umath() except -1
+ *
+ * cdef inline void set_array_base(ndarray arr, object base): # <<<<<<<<<<<<<<
+ * Py_INCREF(base) # important to do this before stealing the reference below!
+ * PyArray_SetBaseObject(arr, base)
+ */
+
+static CYTHON_INLINE void __pyx_f_5numpy_set_array_base(PyArrayObject *__pyx_v_arr, PyObject *__pyx_v_base) {
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":971
+ *
+ * cdef inline void set_array_base(ndarray arr, object base):
+ * Py_INCREF(base) # important to do this before stealing the reference below! # <<<<<<<<<<<<<<
+ * PyArray_SetBaseObject(arr, base)
+ *
+ */
+ Py_INCREF(__pyx_v_base);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":972
+ * cdef inline void set_array_base(ndarray arr, object base):
+ * Py_INCREF(base) # important to do this before stealing the reference below!
+ * PyArray_SetBaseObject(arr, base) # <<<<<<<<<<<<<<
+ *
+ * cdef inline object get_array_base(ndarray arr):
+ */
+ (void)(PyArray_SetBaseObject(__pyx_v_arr, __pyx_v_base));
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":970
+ * int _import_umath() except -1
+ *
+ * cdef inline void set_array_base(ndarray arr, object base): # <<<<<<<<<<<<<<
+ * Py_INCREF(base) # important to do this before stealing the reference below!
+ * PyArray_SetBaseObject(arr, base)
+ */
+
+ /* function exit code */
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":974
+ * PyArray_SetBaseObject(arr, base)
+ *
+ * cdef inline object get_array_base(ndarray arr): # <<<<<<<<<<<<<<
+ * base = PyArray_BASE(arr)
+ * if base is NULL:
+ */
+
+static CYTHON_INLINE PyObject *__pyx_f_5numpy_get_array_base(PyArrayObject *__pyx_v_arr) {
+ PyObject *__pyx_v_base;
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ int __pyx_t_1;
+ __Pyx_RefNannySetupContext("get_array_base", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":975
+ *
+ * cdef inline object get_array_base(ndarray arr):
+ * base = PyArray_BASE(arr) # <<<<<<<<<<<<<<
+ * if base is NULL:
+ * return None
+ */
+ __pyx_v_base = PyArray_BASE(__pyx_v_arr);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":976
+ * cdef inline object get_array_base(ndarray arr):
+ * base = PyArray_BASE(arr)
+ * if base is NULL: # <<<<<<<<<<<<<<
+ * return None
+ * return base
+ */
+ __pyx_t_1 = (__pyx_v_base == NULL);
+ if (__pyx_t_1) {
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":977
+ * base = PyArray_BASE(arr)
+ * if base is NULL:
+ * return None # <<<<<<<<<<<<<<
+ * return base
+ *
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __pyx_r = Py_None; __Pyx_INCREF(Py_None);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":976
+ * cdef inline object get_array_base(ndarray arr):
+ * base = PyArray_BASE(arr)
+ * if base is NULL: # <<<<<<<<<<<<<<
+ * return None
+ * return base
+ */
+ }
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":978
+ * if base is NULL:
+ * return None
+ * return base # <<<<<<<<<<<<<<
+ *
+ * # Versions of the import_* functions which are more suitable for
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __Pyx_INCREF(((PyObject *)__pyx_v_base));
+ __pyx_r = ((PyObject *)__pyx_v_base);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":974
+ * PyArray_SetBaseObject(arr, base)
+ *
+ * cdef inline object get_array_base(ndarray arr): # <<<<<<<<<<<<<<
+ * base = PyArray_BASE(arr)
+ * if base is NULL:
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":982
+ * # Versions of the import_* functions which are more suitable for
+ * # Cython code.
+ * cdef inline int import_array() except -1: # <<<<<<<<<<<<<<
+ * try:
+ * __pyx_import_array()
+ */
+
+static CYTHON_INLINE int __pyx_f_5numpy_import_array(void) {
+ int __pyx_r;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ PyObject *__pyx_t_2 = NULL;
+ PyObject *__pyx_t_3 = NULL;
+ int __pyx_t_4;
+ PyObject *__pyx_t_5 = NULL;
+ PyObject *__pyx_t_6 = NULL;
+ PyObject *__pyx_t_7 = NULL;
+ PyObject *__pyx_t_8 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("import_array", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":983
+ * # Cython code.
+ * cdef inline int import_array() except -1:
+ * try: # <<<<<<<<<<<<<<
+ * __pyx_import_array()
+ * except Exception:
+ */
+ {
+ __Pyx_PyThreadState_declare
+ __Pyx_PyThreadState_assign
+ __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3);
+ __Pyx_XGOTREF(__pyx_t_1);
+ __Pyx_XGOTREF(__pyx_t_2);
+ __Pyx_XGOTREF(__pyx_t_3);
+ /*try:*/ {
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":984
+ * cdef inline int import_array() except -1:
+ * try:
+ * __pyx_import_array() # <<<<<<<<<<<<<<
+ * except Exception:
+ * raise ImportError("numpy.core.multiarray failed to import")
+ */
+ __pyx_t_4 = _import_array(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 984, __pyx_L3_error)
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":983
+ * # Cython code.
+ * cdef inline int import_array() except -1:
+ * try: # <<<<<<<<<<<<<<
+ * __pyx_import_array()
+ * except Exception:
+ */
+ }
+ __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0;
+ __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
+ goto __pyx_L8_try_end;
+ __pyx_L3_error:;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":985
+ * try:
+ * __pyx_import_array()
+ * except Exception: # <<<<<<<<<<<<<<
+ * raise ImportError("numpy.core.multiarray failed to import")
+ *
+ */
+ __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0])));
+ if (__pyx_t_4) {
+ __Pyx_AddTraceback("numpy.import_array", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(1, 985, __pyx_L5_except_error)
+ __Pyx_XGOTREF(__pyx_t_5);
+ __Pyx_XGOTREF(__pyx_t_6);
+ __Pyx_XGOTREF(__pyx_t_7);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":986
+ * __pyx_import_array()
+ * except Exception:
+ * raise ImportError("numpy.core.multiarray failed to import") # <<<<<<<<<<<<<<
+ *
+ * cdef inline int import_umath() except -1:
+ */
+ __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple_, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 986, __pyx_L5_except_error)
+ __Pyx_GOTREF(__pyx_t_8);
+ __Pyx_Raise(__pyx_t_8, 0, 0, 0);
+ __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0;
+ __PYX_ERR(1, 986, __pyx_L5_except_error)
+ }
+ goto __pyx_L5_except_error;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":983
+ * # Cython code.
+ * cdef inline int import_array() except -1:
+ * try: # <<<<<<<<<<<<<<
+ * __pyx_import_array()
+ * except Exception:
+ */
+ __pyx_L5_except_error:;
+ __Pyx_XGIVEREF(__pyx_t_1);
+ __Pyx_XGIVEREF(__pyx_t_2);
+ __Pyx_XGIVEREF(__pyx_t_3);
+ __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3);
+ goto __pyx_L1_error;
+ __pyx_L8_try_end:;
+ }
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":982
+ * # Versions of the import_* functions which are more suitable for
+ * # Cython code.
+ * cdef inline int import_array() except -1: # <<<<<<<<<<<<<<
+ * try:
+ * __pyx_import_array()
+ */
+
+ /* function exit code */
+ __pyx_r = 0;
+ goto __pyx_L0;
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_5);
+ __Pyx_XDECREF(__pyx_t_6);
+ __Pyx_XDECREF(__pyx_t_7);
+ __Pyx_XDECREF(__pyx_t_8);
+ __Pyx_AddTraceback("numpy.import_array", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = -1;
+ __pyx_L0:;
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":988
+ * raise ImportError("numpy.core.multiarray failed to import")
+ *
+ * cdef inline int import_umath() except -1: # <<<<<<<<<<<<<<
+ * try:
+ * _import_umath()
+ */
+
+static CYTHON_INLINE int __pyx_f_5numpy_import_umath(void) {
+ int __pyx_r;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ PyObject *__pyx_t_2 = NULL;
+ PyObject *__pyx_t_3 = NULL;
+ int __pyx_t_4;
+ PyObject *__pyx_t_5 = NULL;
+ PyObject *__pyx_t_6 = NULL;
+ PyObject *__pyx_t_7 = NULL;
+ PyObject *__pyx_t_8 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("import_umath", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":989
+ *
+ * cdef inline int import_umath() except -1:
+ * try: # <<<<<<<<<<<<<<
+ * _import_umath()
+ * except Exception:
+ */
+ {
+ __Pyx_PyThreadState_declare
+ __Pyx_PyThreadState_assign
+ __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3);
+ __Pyx_XGOTREF(__pyx_t_1);
+ __Pyx_XGOTREF(__pyx_t_2);
+ __Pyx_XGOTREF(__pyx_t_3);
+ /*try:*/ {
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":990
+ * cdef inline int import_umath() except -1:
+ * try:
+ * _import_umath() # <<<<<<<<<<<<<<
+ * except Exception:
+ * raise ImportError("numpy.core.umath failed to import")
+ */
+ __pyx_t_4 = _import_umath(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 990, __pyx_L3_error)
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":989
+ *
+ * cdef inline int import_umath() except -1:
+ * try: # <<<<<<<<<<<<<<
+ * _import_umath()
+ * except Exception:
+ */
+ }
+ __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0;
+ __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
+ goto __pyx_L8_try_end;
+ __pyx_L3_error:;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":991
+ * try:
+ * _import_umath()
+ * except Exception: # <<<<<<<<<<<<<<
+ * raise ImportError("numpy.core.umath failed to import")
+ *
+ */
+ __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0])));
+ if (__pyx_t_4) {
+ __Pyx_AddTraceback("numpy.import_umath", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(1, 991, __pyx_L5_except_error)
+ __Pyx_XGOTREF(__pyx_t_5);
+ __Pyx_XGOTREF(__pyx_t_6);
+ __Pyx_XGOTREF(__pyx_t_7);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":992
+ * _import_umath()
+ * except Exception:
+ * raise ImportError("numpy.core.umath failed to import") # <<<<<<<<<<<<<<
+ *
+ * cdef inline int import_ufunc() except -1:
+ */
+ __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__2, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 992, __pyx_L5_except_error)
+ __Pyx_GOTREF(__pyx_t_8);
+ __Pyx_Raise(__pyx_t_8, 0, 0, 0);
+ __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0;
+ __PYX_ERR(1, 992, __pyx_L5_except_error)
+ }
+ goto __pyx_L5_except_error;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":989
+ *
+ * cdef inline int import_umath() except -1:
+ * try: # <<<<<<<<<<<<<<
+ * _import_umath()
+ * except Exception:
+ */
+ __pyx_L5_except_error:;
+ __Pyx_XGIVEREF(__pyx_t_1);
+ __Pyx_XGIVEREF(__pyx_t_2);
+ __Pyx_XGIVEREF(__pyx_t_3);
+ __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3);
+ goto __pyx_L1_error;
+ __pyx_L8_try_end:;
+ }
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":988
+ * raise ImportError("numpy.core.multiarray failed to import")
+ *
+ * cdef inline int import_umath() except -1: # <<<<<<<<<<<<<<
+ * try:
+ * _import_umath()
+ */
+
+ /* function exit code */
+ __pyx_r = 0;
+ goto __pyx_L0;
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_5);
+ __Pyx_XDECREF(__pyx_t_6);
+ __Pyx_XDECREF(__pyx_t_7);
+ __Pyx_XDECREF(__pyx_t_8);
+ __Pyx_AddTraceback("numpy.import_umath", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = -1;
+ __pyx_L0:;
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":994
+ * raise ImportError("numpy.core.umath failed to import")
+ *
+ * cdef inline int import_ufunc() except -1: # <<<<<<<<<<<<<<
+ * try:
+ * _import_umath()
+ */
+
+static CYTHON_INLINE int __pyx_f_5numpy_import_ufunc(void) {
+ int __pyx_r;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ PyObject *__pyx_t_2 = NULL;
+ PyObject *__pyx_t_3 = NULL;
+ int __pyx_t_4;
+ PyObject *__pyx_t_5 = NULL;
+ PyObject *__pyx_t_6 = NULL;
+ PyObject *__pyx_t_7 = NULL;
+ PyObject *__pyx_t_8 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("import_ufunc", 1);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":995
+ *
+ * cdef inline int import_ufunc() except -1:
+ * try: # <<<<<<<<<<<<<<
+ * _import_umath()
+ * except Exception:
+ */
+ {
+ __Pyx_PyThreadState_declare
+ __Pyx_PyThreadState_assign
+ __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3);
+ __Pyx_XGOTREF(__pyx_t_1);
+ __Pyx_XGOTREF(__pyx_t_2);
+ __Pyx_XGOTREF(__pyx_t_3);
+ /*try:*/ {
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":996
+ * cdef inline int import_ufunc() except -1:
+ * try:
+ * _import_umath() # <<<<<<<<<<<<<<
+ * except Exception:
+ * raise ImportError("numpy.core.umath failed to import")
+ */
+ __pyx_t_4 = _import_umath(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 996, __pyx_L3_error)
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":995
+ *
+ * cdef inline int import_ufunc() except -1:
+ * try: # <<<<<<<<<<<<<<
+ * _import_umath()
+ * except Exception:
+ */
+ }
+ __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0;
+ __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
+ goto __pyx_L8_try_end;
+ __pyx_L3_error:;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":997
+ * try:
+ * _import_umath()
+ * except Exception: # <<<<<<<<<<<<<<
+ * raise ImportError("numpy.core.umath failed to import")
+ *
+ */
+ __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0])));
+ if (__pyx_t_4) {
+ __Pyx_AddTraceback("numpy.import_ufunc", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(1, 997, __pyx_L5_except_error)
+ __Pyx_XGOTREF(__pyx_t_5);
+ __Pyx_XGOTREF(__pyx_t_6);
+ __Pyx_XGOTREF(__pyx_t_7);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":998
+ * _import_umath()
+ * except Exception:
+ * raise ImportError("numpy.core.umath failed to import") # <<<<<<<<<<<<<<
+ *
+ *
+ */
+ __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__2, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 998, __pyx_L5_except_error)
+ __Pyx_GOTREF(__pyx_t_8);
+ __Pyx_Raise(__pyx_t_8, 0, 0, 0);
+ __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0;
+ __PYX_ERR(1, 998, __pyx_L5_except_error)
+ }
+ goto __pyx_L5_except_error;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":995
+ *
+ * cdef inline int import_ufunc() except -1:
+ * try: # <<<<<<<<<<<<<<
+ * _import_umath()
+ * except Exception:
+ */
+ __pyx_L5_except_error:;
+ __Pyx_XGIVEREF(__pyx_t_1);
+ __Pyx_XGIVEREF(__pyx_t_2);
+ __Pyx_XGIVEREF(__pyx_t_3);
+ __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3);
+ goto __pyx_L1_error;
+ __pyx_L8_try_end:;
+ }
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":994
+ * raise ImportError("numpy.core.umath failed to import")
+ *
+ * cdef inline int import_ufunc() except -1: # <<<<<<<<<<<<<<
+ * try:
+ * _import_umath()
+ */
+
+ /* function exit code */
+ __pyx_r = 0;
+ goto __pyx_L0;
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_5);
+ __Pyx_XDECREF(__pyx_t_6);
+ __Pyx_XDECREF(__pyx_t_7);
+ __Pyx_XDECREF(__pyx_t_8);
+ __Pyx_AddTraceback("numpy.import_ufunc", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = -1;
+ __pyx_L0:;
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1001
+ *
+ *
+ * cdef inline bint is_timedelta64_object(object obj): # <<<<<<<<<<<<<<
+ * """
+ * Cython equivalent of `isinstance(obj, np.timedelta64)`
+ */
+
+static CYTHON_INLINE int __pyx_f_5numpy_is_timedelta64_object(PyObject *__pyx_v_obj) {
+ int __pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1013
+ * bool
+ * """
+ * return PyObject_TypeCheck(obj, &PyTimedeltaArrType_Type) # <<<<<<<<<<<<<<
+ *
+ *
+ */
+ __pyx_r = PyObject_TypeCheck(__pyx_v_obj, (&PyTimedeltaArrType_Type));
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1001
+ *
+ *
+ * cdef inline bint is_timedelta64_object(object obj): # <<<<<<<<<<<<<<
+ * """
+ * Cython equivalent of `isinstance(obj, np.timedelta64)`
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1016
+ *
+ *
+ * cdef inline bint is_datetime64_object(object obj): # <<<<<<<<<<<<<<
+ * """
+ * Cython equivalent of `isinstance(obj, np.datetime64)`
+ */
+
+static CYTHON_INLINE int __pyx_f_5numpy_is_datetime64_object(PyObject *__pyx_v_obj) {
+ int __pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1028
+ * bool
+ * """
+ * return PyObject_TypeCheck(obj, &PyDatetimeArrType_Type) # <<<<<<<<<<<<<<
+ *
+ *
+ */
+ __pyx_r = PyObject_TypeCheck(__pyx_v_obj, (&PyDatetimeArrType_Type));
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1016
+ *
+ *
+ * cdef inline bint is_datetime64_object(object obj): # <<<<<<<<<<<<<<
+ * """
+ * Cython equivalent of `isinstance(obj, np.datetime64)`
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1031
+ *
+ *
+ * cdef inline npy_datetime get_datetime64_value(object obj) nogil: # <<<<<<<<<<<<<<
+ * """
+ * returns the int64 value underlying scalar numpy datetime64 object
+ */
+
+static CYTHON_INLINE npy_datetime __pyx_f_5numpy_get_datetime64_value(PyObject *__pyx_v_obj) {
+ npy_datetime __pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1038
+ * also needed. That can be found using `get_datetime64_unit`.
+ * """
+ * return (obj).obval # <<<<<<<<<<<<<<
+ *
+ *
+ */
+ __pyx_r = ((PyDatetimeScalarObject *)__pyx_v_obj)->obval;
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1031
+ *
+ *
+ * cdef inline npy_datetime get_datetime64_value(object obj) nogil: # <<<<<<<<<<<<<<
+ * """
+ * returns the int64 value underlying scalar numpy datetime64 object
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1041
+ *
+ *
+ * cdef inline npy_timedelta get_timedelta64_value(object obj) nogil: # <<<<<<<<<<<<<<
+ * """
+ * returns the int64 value underlying scalar numpy timedelta64 object
+ */
+
+static CYTHON_INLINE npy_timedelta __pyx_f_5numpy_get_timedelta64_value(PyObject *__pyx_v_obj) {
+ npy_timedelta __pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1045
+ * returns the int64 value underlying scalar numpy timedelta64 object
+ * """
+ * return (obj).obval # <<<<<<<<<<<<<<
+ *
+ *
+ */
+ __pyx_r = ((PyTimedeltaScalarObject *)__pyx_v_obj)->obval;
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1041
+ *
+ *
+ * cdef inline npy_timedelta get_timedelta64_value(object obj) nogil: # <<<<<<<<<<<<<<
+ * """
+ * returns the int64 value underlying scalar numpy timedelta64 object
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1048
+ *
+ *
+ * cdef inline NPY_DATETIMEUNIT get_datetime64_unit(object obj) nogil: # <<<<<<<<<<<<<<
+ * """
+ * returns the unit part of the dtype for a numpy datetime64 object.
+ */
+
+static CYTHON_INLINE NPY_DATETIMEUNIT __pyx_f_5numpy_get_datetime64_unit(PyObject *__pyx_v_obj) {
+ NPY_DATETIMEUNIT __pyx_r;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1052
+ * returns the unit part of the dtype for a numpy datetime64 object.
+ * """
+ * return (obj).obmeta.base # <<<<<<<<<<<<<<
+ */
+ __pyx_r = ((NPY_DATETIMEUNIT)((PyDatetimeScalarObject *)__pyx_v_obj)->obmeta.base);
+ goto __pyx_L0;
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1048
+ *
+ *
+ * cdef inline NPY_DATETIMEUNIT get_datetime64_unit(object obj) nogil: # <<<<<<<<<<<<<<
+ * """
+ * returns the unit part of the dtype for a numpy datetime64 object.
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "nms/cpu_nms.pyx":11
+ * cimport numpy as np
+ *
+ * cdef inline np.float32_t max(np.float32_t a, np.float32_t b): # <<<<<<<<<<<<<<
+ * return a if a >= b else b
+ *
+ */
+
+static CYTHON_INLINE __pyx_t_5numpy_float32_t __pyx_f_3nms_7cpu_nms_max(__pyx_t_5numpy_float32_t __pyx_v_a, __pyx_t_5numpy_float32_t __pyx_v_b) {
+ __pyx_t_5numpy_float32_t __pyx_r;
+ __pyx_t_5numpy_float32_t __pyx_t_1;
+ int __pyx_t_2;
+
+ /* "nms/cpu_nms.pyx":12
+ *
+ * cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
+ * return a if a >= b else b # <<<<<<<<<<<<<<
+ *
+ * cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
+ */
+ __pyx_t_2 = (__pyx_v_a >= __pyx_v_b);
+ if (__pyx_t_2) {
+ __pyx_t_1 = __pyx_v_a;
+ } else {
+ __pyx_t_1 = __pyx_v_b;
+ }
+ __pyx_r = __pyx_t_1;
+ goto __pyx_L0;
+
+ /* "nms/cpu_nms.pyx":11
+ * cimport numpy as np
+ *
+ * cdef inline np.float32_t max(np.float32_t a, np.float32_t b): # <<<<<<<<<<<<<<
+ * return a if a >= b else b
+ *
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "nms/cpu_nms.pyx":14
+ * return a if a >= b else b
+ *
+ * cdef inline np.float32_t min(np.float32_t a, np.float32_t b): # <<<<<<<<<<<<<<
+ * return a if a <= b else b
+ *
+ */
+
+static CYTHON_INLINE __pyx_t_5numpy_float32_t __pyx_f_3nms_7cpu_nms_min(__pyx_t_5numpy_float32_t __pyx_v_a, __pyx_t_5numpy_float32_t __pyx_v_b) {
+ __pyx_t_5numpy_float32_t __pyx_r;
+ __pyx_t_5numpy_float32_t __pyx_t_1;
+ int __pyx_t_2;
+
+ /* "nms/cpu_nms.pyx":15
+ *
+ * cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
+ * return a if a <= b else b # <<<<<<<<<<<<<<
+ *
+ * def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
+ */
+ __pyx_t_2 = (__pyx_v_a <= __pyx_v_b);
+ if (__pyx_t_2) {
+ __pyx_t_1 = __pyx_v_a;
+ } else {
+ __pyx_t_1 = __pyx_v_b;
+ }
+ __pyx_r = __pyx_t_1;
+ goto __pyx_L0;
+
+ /* "nms/cpu_nms.pyx":14
+ * return a if a >= b else b
+ *
+ * cdef inline np.float32_t min(np.float32_t a, np.float32_t b): # <<<<<<<<<<<<<<
+ * return a if a <= b else b
+ *
+ */
+
+ /* function exit code */
+ __pyx_L0:;
+ return __pyx_r;
+}
+
+/* "nms/cpu_nms.pyx":17
+ * return a if a <= b else b
+ *
+ * def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ */
+
+/* Python wrapper */
+static PyObject *__pyx_pw_3nms_7cpu_nms_1cpu_nms(PyObject *__pyx_self,
+#if CYTHON_METH_FASTCALL
+PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
+#else
+PyObject *__pyx_args, PyObject *__pyx_kwds
+#endif
+); /*proto*/
+static PyMethodDef __pyx_mdef_3nms_7cpu_nms_1cpu_nms = {"cpu_nms", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_3nms_7cpu_nms_1cpu_nms, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0};
+static PyObject *__pyx_pw_3nms_7cpu_nms_1cpu_nms(PyObject *__pyx_self,
+#if CYTHON_METH_FASTCALL
+PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
+#else
+PyObject *__pyx_args, PyObject *__pyx_kwds
+#endif
+) {
+ PyArrayObject *__pyx_v_dets = 0;
+ PyObject *__pyx_v_thresh = 0;
+ #if !CYTHON_METH_FASTCALL
+ CYTHON_UNUSED Py_ssize_t __pyx_nargs;
+ #endif
+ CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
+ PyObject* values[2] = {0,0};
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ PyObject *__pyx_r = 0;
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("cpu_nms (wrapper)", 0);
+ #if !CYTHON_METH_FASTCALL
+ #if CYTHON_ASSUME_SAFE_MACROS
+ __pyx_nargs = PyTuple_GET_SIZE(__pyx_args);
+ #else
+ __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL;
+ #endif
+ #endif
+ __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs);
+ {
+ PyObject **__pyx_pyargnames[] = {&__pyx_n_s_dets,&__pyx_n_s_thresh,0};
+ if (__pyx_kwds) {
+ Py_ssize_t kw_args;
+ switch (__pyx_nargs) {
+ case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1);
+ CYTHON_FALLTHROUGH;
+ case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0);
+ CYTHON_FALLTHROUGH;
+ case 0: break;
+ default: goto __pyx_L5_argtuple_error;
+ }
+ kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds);
+ switch (__pyx_nargs) {
+ case 0:
+ if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_dets)) != 0)) {
+ (void)__Pyx_Arg_NewRef_FASTCALL(values[0]);
+ kw_args--;
+ }
+ else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 17, __pyx_L3_error)
+ else goto __pyx_L5_argtuple_error;
+ CYTHON_FALLTHROUGH;
+ case 1:
+ if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_thresh)) != 0)) {
+ (void)__Pyx_Arg_NewRef_FASTCALL(values[1]);
+ kw_args--;
+ }
+ else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 17, __pyx_L3_error)
+ else {
+ __Pyx_RaiseArgtupleInvalid("cpu_nms", 1, 2, 2, 1); __PYX_ERR(0, 17, __pyx_L3_error)
+ }
+ }
+ if (unlikely(kw_args > 0)) {
+ const Py_ssize_t kwd_pos_args = __pyx_nargs;
+ if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "cpu_nms") < 0)) __PYX_ERR(0, 17, __pyx_L3_error)
+ }
+ } else if (unlikely(__pyx_nargs != 2)) {
+ goto __pyx_L5_argtuple_error;
+ } else {
+ values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0);
+ values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1);
+ }
+ __pyx_v_dets = ((PyArrayObject *)values[0]);
+ __pyx_v_thresh = ((PyObject*)values[1]);
+ }
+ goto __pyx_L6_skip;
+ __pyx_L5_argtuple_error:;
+ __Pyx_RaiseArgtupleInvalid("cpu_nms", 1, 2, 2, __pyx_nargs); __PYX_ERR(0, 17, __pyx_L3_error)
+ __pyx_L6_skip:;
+ goto __pyx_L4_argument_unpacking_done;
+ __pyx_L3_error:;
+ {
+ Py_ssize_t __pyx_temp;
+ for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
+ __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]);
+ }
+ }
+ __Pyx_AddTraceback("nms.cpu_nms.cpu_nms", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __Pyx_RefNannyFinishContext();
+ return NULL;
+ __pyx_L4_argument_unpacking_done:;
+ if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_dets), __pyx_ptype_5numpy_ndarray, 1, "dets", 0))) __PYX_ERR(0, 17, __pyx_L1_error)
+ if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_thresh), (&PyFloat_Type), 1, "thresh", 1))) __PYX_ERR(0, 17, __pyx_L1_error)
+ __pyx_r = __pyx_pf_3nms_7cpu_nms_cpu_nms(__pyx_self, __pyx_v_dets, __pyx_v_thresh);
+
+ /* function exit code */
+ goto __pyx_L0;
+ __pyx_L1_error:;
+ __pyx_r = NULL;
+ __pyx_L0:;
+ {
+ Py_ssize_t __pyx_temp;
+ for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
+ __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]);
+ }
+ }
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+static PyObject *__pyx_pf_3nms_7cpu_nms_cpu_nms(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_dets, PyObject *__pyx_v_thresh) {
+ PyArrayObject *__pyx_v_x1 = 0;
+ PyArrayObject *__pyx_v_y1 = 0;
+ PyArrayObject *__pyx_v_x2 = 0;
+ PyArrayObject *__pyx_v_y2 = 0;
+ PyArrayObject *__pyx_v_scores = 0;
+ PyArrayObject *__pyx_v_areas = 0;
+ PyArrayObject *__pyx_v_order = 0;
+ int __pyx_v_ndets;
+ PyArrayObject *__pyx_v_suppressed = 0;
+ int __pyx_v__i;
+ int __pyx_v__j;
+ int __pyx_v_i;
+ int __pyx_v_j;
+ __pyx_t_5numpy_float32_t __pyx_v_ix1;
+ __pyx_t_5numpy_float32_t __pyx_v_iy1;
+ __pyx_t_5numpy_float32_t __pyx_v_ix2;
+ __pyx_t_5numpy_float32_t __pyx_v_iy2;
+ __pyx_t_5numpy_float32_t __pyx_v_iarea;
+ __pyx_t_5numpy_float32_t __pyx_v_xx1;
+ __pyx_t_5numpy_float32_t __pyx_v_yy1;
+ __pyx_t_5numpy_float32_t __pyx_v_xx2;
+ __pyx_t_5numpy_float32_t __pyx_v_yy2;
+ __pyx_t_5numpy_float32_t __pyx_v_w;
+ __pyx_t_5numpy_float32_t __pyx_v_h;
+ __pyx_t_5numpy_float32_t __pyx_v_inter;
+ __pyx_t_5numpy_float32_t __pyx_v_ovr;
+ PyObject *__pyx_v_keep = NULL;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_areas;
+ __Pyx_Buffer __pyx_pybuffer_areas;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_dets;
+ __Pyx_Buffer __pyx_pybuffer_dets;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_order;
+ __Pyx_Buffer __pyx_pybuffer_order;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_scores;
+ __Pyx_Buffer __pyx_pybuffer_scores;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_suppressed;
+ __Pyx_Buffer __pyx_pybuffer_suppressed;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_x1;
+ __Pyx_Buffer __pyx_pybuffer_x1;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_x2;
+ __Pyx_Buffer __pyx_pybuffer_x2;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_y1;
+ __Pyx_Buffer __pyx_pybuffer_y1;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_y2;
+ __Pyx_Buffer __pyx_pybuffer_y2;
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ PyArrayObject *__pyx_t_2 = NULL;
+ PyArrayObject *__pyx_t_3 = NULL;
+ PyArrayObject *__pyx_t_4 = NULL;
+ PyArrayObject *__pyx_t_5 = NULL;
+ PyArrayObject *__pyx_t_6 = NULL;
+ PyObject *__pyx_t_7 = NULL;
+ PyObject *__pyx_t_8 = NULL;
+ PyArrayObject *__pyx_t_9 = NULL;
+ unsigned int __pyx_t_10;
+ PyArrayObject *__pyx_t_11 = NULL;
+ npy_intp *__pyx_t_12;
+ PyObject *__pyx_t_13 = NULL;
+ PyObject *__pyx_t_14 = NULL;
+ PyArrayObject *__pyx_t_15 = NULL;
+ int __pyx_t_16;
+ int __pyx_t_17;
+ int __pyx_t_18;
+ Py_ssize_t __pyx_t_19;
+ int __pyx_t_20;
+ int __pyx_t_21;
+ int __pyx_t_22;
+ int __pyx_t_23;
+ int __pyx_t_24;
+ int __pyx_t_25;
+ __pyx_t_5numpy_float32_t __pyx_t_26;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("cpu_nms", 1);
+ __pyx_pybuffer_x1.pybuffer.buf = NULL;
+ __pyx_pybuffer_x1.refcount = 0;
+ __pyx_pybuffernd_x1.data = NULL;
+ __pyx_pybuffernd_x1.rcbuffer = &__pyx_pybuffer_x1;
+ __pyx_pybuffer_y1.pybuffer.buf = NULL;
+ __pyx_pybuffer_y1.refcount = 0;
+ __pyx_pybuffernd_y1.data = NULL;
+ __pyx_pybuffernd_y1.rcbuffer = &__pyx_pybuffer_y1;
+ __pyx_pybuffer_x2.pybuffer.buf = NULL;
+ __pyx_pybuffer_x2.refcount = 0;
+ __pyx_pybuffernd_x2.data = NULL;
+ __pyx_pybuffernd_x2.rcbuffer = &__pyx_pybuffer_x2;
+ __pyx_pybuffer_y2.pybuffer.buf = NULL;
+ __pyx_pybuffer_y2.refcount = 0;
+ __pyx_pybuffernd_y2.data = NULL;
+ __pyx_pybuffernd_y2.rcbuffer = &__pyx_pybuffer_y2;
+ __pyx_pybuffer_scores.pybuffer.buf = NULL;
+ __pyx_pybuffer_scores.refcount = 0;
+ __pyx_pybuffernd_scores.data = NULL;
+ __pyx_pybuffernd_scores.rcbuffer = &__pyx_pybuffer_scores;
+ __pyx_pybuffer_areas.pybuffer.buf = NULL;
+ __pyx_pybuffer_areas.refcount = 0;
+ __pyx_pybuffernd_areas.data = NULL;
+ __pyx_pybuffernd_areas.rcbuffer = &__pyx_pybuffer_areas;
+ __pyx_pybuffer_order.pybuffer.buf = NULL;
+ __pyx_pybuffer_order.refcount = 0;
+ __pyx_pybuffernd_order.data = NULL;
+ __pyx_pybuffernd_order.rcbuffer = &__pyx_pybuffer_order;
+ __pyx_pybuffer_suppressed.pybuffer.buf = NULL;
+ __pyx_pybuffer_suppressed.refcount = 0;
+ __pyx_pybuffernd_suppressed.data = NULL;
+ __pyx_pybuffernd_suppressed.rcbuffer = &__pyx_pybuffer_suppressed;
+ __pyx_pybuffer_dets.pybuffer.buf = NULL;
+ __pyx_pybuffer_dets.refcount = 0;
+ __pyx_pybuffernd_dets.data = NULL;
+ __pyx_pybuffernd_dets.rcbuffer = &__pyx_pybuffer_dets;
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_dets.rcbuffer->pybuffer, (PyObject*)__pyx_v_dets, &__Pyx_TypeInfo_nn___pyx_t_5numpy_float32_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) __PYX_ERR(0, 17, __pyx_L1_error)
+ }
+ __pyx_pybuffernd_dets.diminfo[0].strides = __pyx_pybuffernd_dets.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_dets.diminfo[0].shape = __pyx_pybuffernd_dets.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_dets.diminfo[1].strides = __pyx_pybuffernd_dets.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_dets.diminfo[1].shape = __pyx_pybuffernd_dets.rcbuffer->pybuffer.shape[1];
+
+ /* "nms/cpu_nms.pyx":18
+ *
+ * def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+ */
+ __pyx_t_1 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_dets), __pyx_tuple__4); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 18, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 18, __pyx_L1_error)
+ __pyx_t_2 = ((PyArrayObject *)__pyx_t_1);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_x1.rcbuffer->pybuffer, (PyObject*)__pyx_t_2, &__Pyx_TypeInfo_nn___pyx_t_5numpy_float32_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) {
+ __pyx_v_x1 = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_x1.rcbuffer->pybuffer.buf = NULL;
+ __PYX_ERR(0, 18, __pyx_L1_error)
+ } else {__pyx_pybuffernd_x1.diminfo[0].strides = __pyx_pybuffernd_x1.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_x1.diminfo[0].shape = __pyx_pybuffernd_x1.rcbuffer->pybuffer.shape[0];
+ }
+ }
+ __pyx_t_2 = 0;
+ __pyx_v_x1 = ((PyArrayObject *)__pyx_t_1);
+ __pyx_t_1 = 0;
+
+ /* "nms/cpu_nms.pyx":19
+ * def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+ * cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
+ */
+ __pyx_t_1 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_dets), __pyx_tuple__5); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 19, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 19, __pyx_L1_error)
+ __pyx_t_3 = ((PyArrayObject *)__pyx_t_1);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_y1.rcbuffer->pybuffer, (PyObject*)__pyx_t_3, &__Pyx_TypeInfo_nn___pyx_t_5numpy_float32_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) {
+ __pyx_v_y1 = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_y1.rcbuffer->pybuffer.buf = NULL;
+ __PYX_ERR(0, 19, __pyx_L1_error)
+ } else {__pyx_pybuffernd_y1.diminfo[0].strides = __pyx_pybuffernd_y1.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_y1.diminfo[0].shape = __pyx_pybuffernd_y1.rcbuffer->pybuffer.shape[0];
+ }
+ }
+ __pyx_t_3 = 0;
+ __pyx_v_y1 = ((PyArrayObject *)__pyx_t_1);
+ __pyx_t_1 = 0;
+
+ /* "nms/cpu_nms.pyx":20
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
+ * cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
+ */
+ __pyx_t_1 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_dets), __pyx_tuple__6); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 20, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 20, __pyx_L1_error)
+ __pyx_t_4 = ((PyArrayObject *)__pyx_t_1);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_x2.rcbuffer->pybuffer, (PyObject*)__pyx_t_4, &__Pyx_TypeInfo_nn___pyx_t_5numpy_float32_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) {
+ __pyx_v_x2 = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_x2.rcbuffer->pybuffer.buf = NULL;
+ __PYX_ERR(0, 20, __pyx_L1_error)
+ } else {__pyx_pybuffernd_x2.diminfo[0].strides = __pyx_pybuffernd_x2.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_x2.diminfo[0].shape = __pyx_pybuffernd_x2.rcbuffer->pybuffer.shape[0];
+ }
+ }
+ __pyx_t_4 = 0;
+ __pyx_v_x2 = ((PyArrayObject *)__pyx_t_1);
+ __pyx_t_1 = 0;
+
+ /* "nms/cpu_nms.pyx":21
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+ * cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
+ *
+ */
+ __pyx_t_1 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_dets), __pyx_tuple__7); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 21, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 21, __pyx_L1_error)
+ __pyx_t_5 = ((PyArrayObject *)__pyx_t_1);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_y2.rcbuffer->pybuffer, (PyObject*)__pyx_t_5, &__Pyx_TypeInfo_nn___pyx_t_5numpy_float32_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) {
+ __pyx_v_y2 = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_y2.rcbuffer->pybuffer.buf = NULL;
+ __PYX_ERR(0, 21, __pyx_L1_error)
+ } else {__pyx_pybuffernd_y2.diminfo[0].strides = __pyx_pybuffernd_y2.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_y2.diminfo[0].shape = __pyx_pybuffernd_y2.rcbuffer->pybuffer.shape[0];
+ }
+ }
+ __pyx_t_5 = 0;
+ __pyx_v_y2 = ((PyArrayObject *)__pyx_t_1);
+ __pyx_t_1 = 0;
+
+ /* "nms/cpu_nms.pyx":22
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+ * cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
+ * cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] # <<<<<<<<<<<<<<
+ *
+ * cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ */
+ __pyx_t_1 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_dets), __pyx_tuple__8); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 22, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 22, __pyx_L1_error)
+ __pyx_t_6 = ((PyArrayObject *)__pyx_t_1);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_scores.rcbuffer->pybuffer, (PyObject*)__pyx_t_6, &__Pyx_TypeInfo_nn___pyx_t_5numpy_float32_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) {
+ __pyx_v_scores = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_scores.rcbuffer->pybuffer.buf = NULL;
+ __PYX_ERR(0, 22, __pyx_L1_error)
+ } else {__pyx_pybuffernd_scores.diminfo[0].strides = __pyx_pybuffernd_scores.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_scores.diminfo[0].shape = __pyx_pybuffernd_scores.rcbuffer->pybuffer.shape[0];
+ }
+ }
+ __pyx_t_6 = 0;
+ __pyx_v_scores = ((PyArrayObject *)__pyx_t_1);
+ __pyx_t_1 = 0;
+
+ /* "nms/cpu_nms.pyx":24
+ * cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
+ *
+ * cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1]
+ *
+ */
+ __pyx_t_1 = PyNumber_Subtract(((PyObject *)__pyx_v_x2), ((PyObject *)__pyx_v_x1)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_t_7 = __Pyx_PyInt_AddObjC(__pyx_t_1, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 24, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_7);
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __pyx_t_1 = PyNumber_Subtract(((PyObject *)__pyx_v_y2), ((PyObject *)__pyx_v_y1)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_t_8 = __Pyx_PyInt_AddObjC(__pyx_t_1, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 24, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_8);
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __pyx_t_1 = PyNumber_Multiply(__pyx_t_7, __pyx_t_8); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0;
+ __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0;
+ if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 24, __pyx_L1_error)
+ __pyx_t_9 = ((PyArrayObject *)__pyx_t_1);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_areas.rcbuffer->pybuffer, (PyObject*)__pyx_t_9, &__Pyx_TypeInfo_nn___pyx_t_5numpy_float32_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) {
+ __pyx_v_areas = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_areas.rcbuffer->pybuffer.buf = NULL;
+ __PYX_ERR(0, 24, __pyx_L1_error)
+ } else {__pyx_pybuffernd_areas.diminfo[0].strides = __pyx_pybuffernd_areas.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_areas.diminfo[0].shape = __pyx_pybuffernd_areas.rcbuffer->pybuffer.shape[0];
+ }
+ }
+ __pyx_t_9 = 0;
+ __pyx_v_areas = ((PyArrayObject *)__pyx_t_1);
+ __pyx_t_1 = 0;
+
+ /* "nms/cpu_nms.pyx":25
+ *
+ * cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ * cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1] # <<<<<<<<<<<<<<
+ *
+ * cdef int ndets = dets.shape[0]
+ */
+ __pyx_t_8 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_scores), __pyx_n_s_argsort); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 25, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_8);
+ __pyx_t_7 = NULL;
+ __pyx_t_10 = 0;
+ #if CYTHON_UNPACK_METHODS
+ if (likely(PyMethod_Check(__pyx_t_8))) {
+ __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_8);
+ if (likely(__pyx_t_7)) {
+ PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_8);
+ __Pyx_INCREF(__pyx_t_7);
+ __Pyx_INCREF(function);
+ __Pyx_DECREF_SET(__pyx_t_8, function);
+ __pyx_t_10 = 1;
+ }
+ }
+ #endif
+ {
+ PyObject *__pyx_callargs[2] = {__pyx_t_7, NULL};
+ __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_8, __pyx_callargs+1-__pyx_t_10, 0+__pyx_t_10);
+ __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0;
+ if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 25, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0;
+ }
+ __pyx_t_8 = __Pyx_PyObject_GetItem(__pyx_t_1, __pyx_slice__9); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 25, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_8);
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ if (!(likely(((__pyx_t_8) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_8, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 25, __pyx_L1_error)
+ __pyx_t_11 = ((PyArrayObject *)__pyx_t_8);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_order.rcbuffer->pybuffer, (PyObject*)__pyx_t_11, &__Pyx_TypeInfo_nn___pyx_t_5numpy_int_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) {
+ __pyx_v_order = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_order.rcbuffer->pybuffer.buf = NULL;
+ __PYX_ERR(0, 25, __pyx_L1_error)
+ } else {__pyx_pybuffernd_order.diminfo[0].strides = __pyx_pybuffernd_order.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_order.diminfo[0].shape = __pyx_pybuffernd_order.rcbuffer->pybuffer.shape[0];
+ }
+ }
+ __pyx_t_11 = 0;
+ __pyx_v_order = ((PyArrayObject *)__pyx_t_8);
+ __pyx_t_8 = 0;
+
+ /* "nms/cpu_nms.pyx":27
+ * cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1]
+ *
+ * cdef int ndets = dets.shape[0] # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.int_t, ndim=1] suppressed = \
+ * np.zeros((ndets), dtype=np.int)
+ */
+ __pyx_t_12 = __pyx_f_5numpy_7ndarray_5shape_shape(((PyArrayObject *)__pyx_v_dets)); if (unlikely(__pyx_t_12 == ((npy_intp *)NULL) && PyErr_Occurred())) __PYX_ERR(0, 27, __pyx_L1_error)
+ __pyx_v_ndets = (__pyx_t_12[0]);
+
+ /* "nms/cpu_nms.pyx":29
+ * cdef int ndets = dets.shape[0]
+ * cdef np.ndarray[np.int_t, ndim=1] suppressed = \
+ * np.zeros((ndets), dtype=np.int) # <<<<<<<<<<<<<<
+ *
+ * # nominal indices
+ */
+ __Pyx_GetModuleGlobalName(__pyx_t_8, __pyx_n_s_np); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 29, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_8);
+ __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_8, __pyx_n_s_zeros); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 29, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0;
+ __pyx_t_8 = __Pyx_PyInt_From_int(__pyx_v_ndets); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 29, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_8);
+ __pyx_t_7 = PyTuple_New(1); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 29, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_7);
+ __Pyx_GIVEREF(__pyx_t_8);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_7, 0, __pyx_t_8)) __PYX_ERR(0, 29, __pyx_L1_error);
+ __pyx_t_8 = 0;
+ __pyx_t_8 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 29, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_8);
+ __Pyx_GetModuleGlobalName(__pyx_t_13, __pyx_n_s_np); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 29, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_13);
+ __pyx_t_14 = __Pyx_PyObject_GetAttrStr(__pyx_t_13, __pyx_n_s_int); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 29, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_14);
+ __Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0;
+ if (PyDict_SetItem(__pyx_t_8, __pyx_n_s_dtype, __pyx_t_14) < 0) __PYX_ERR(0, 29, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0;
+ __pyx_t_14 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_7, __pyx_t_8); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 29, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_14);
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0;
+ __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0;
+ if (!(likely(((__pyx_t_14) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_14, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 29, __pyx_L1_error)
+ __pyx_t_15 = ((PyArrayObject *)__pyx_t_14);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_suppressed.rcbuffer->pybuffer, (PyObject*)__pyx_t_15, &__Pyx_TypeInfo_nn___pyx_t_5numpy_int_t, PyBUF_FORMAT| PyBUF_STRIDES| PyBUF_WRITABLE, 1, 0, __pyx_stack) == -1)) {
+ __pyx_v_suppressed = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_suppressed.rcbuffer->pybuffer.buf = NULL;
+ __PYX_ERR(0, 28, __pyx_L1_error)
+ } else {__pyx_pybuffernd_suppressed.diminfo[0].strides = __pyx_pybuffernd_suppressed.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_suppressed.diminfo[0].shape = __pyx_pybuffernd_suppressed.rcbuffer->pybuffer.shape[0];
+ }
+ }
+ __pyx_t_15 = 0;
+ __pyx_v_suppressed = ((PyArrayObject *)__pyx_t_14);
+ __pyx_t_14 = 0;
+
+ /* "nms/cpu_nms.pyx":42
+ * cdef np.float32_t inter, ovr
+ *
+ * keep = [] # <<<<<<<<<<<<<<
+ * for _i in range(ndets):
+ * i = order[_i]
+ */
+ __pyx_t_14 = PyList_New(0); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 42, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_14);
+ __pyx_v_keep = ((PyObject*)__pyx_t_14);
+ __pyx_t_14 = 0;
+
+ /* "nms/cpu_nms.pyx":43
+ *
+ * keep = []
+ * for _i in range(ndets): # <<<<<<<<<<<<<<
+ * i = order[_i]
+ * if suppressed[i] == 1:
+ */
+ __pyx_t_16 = __pyx_v_ndets;
+ __pyx_t_17 = __pyx_t_16;
+ for (__pyx_t_18 = 0; __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
+ __pyx_v__i = __pyx_t_18;
+
+ /* "nms/cpu_nms.pyx":44
+ * keep = []
+ * for _i in range(ndets):
+ * i = order[_i] # <<<<<<<<<<<<<<
+ * if suppressed[i] == 1:
+ * continue
+ */
+ __pyx_t_19 = __pyx_v__i;
+ __pyx_t_20 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_order.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_20 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_order.diminfo[0].shape)) __pyx_t_20 = 0;
+ if (unlikely(__pyx_t_20 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_20);
+ __PYX_ERR(0, 44, __pyx_L1_error)
+ }
+ __pyx_v_i = (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_int_t *, __pyx_pybuffernd_order.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_order.diminfo[0].strides));
+
+ /* "nms/cpu_nms.pyx":45
+ * for _i in range(ndets):
+ * i = order[_i]
+ * if suppressed[i] == 1: # <<<<<<<<<<<<<<
+ * continue
+ * keep.append(i)
+ */
+ __pyx_t_19 = __pyx_v_i;
+ __pyx_t_20 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_suppressed.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_20 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_suppressed.diminfo[0].shape)) __pyx_t_20 = 0;
+ if (unlikely(__pyx_t_20 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_20);
+ __PYX_ERR(0, 45, __pyx_L1_error)
+ }
+ __pyx_t_21 = ((*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_int_t *, __pyx_pybuffernd_suppressed.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_suppressed.diminfo[0].strides)) == 1);
+ if (__pyx_t_21) {
+
+ /* "nms/cpu_nms.pyx":46
+ * i = order[_i]
+ * if suppressed[i] == 1:
+ * continue # <<<<<<<<<<<<<<
+ * keep.append(i)
+ * ix1 = x1[i]
+ */
+ goto __pyx_L3_continue;
+
+ /* "nms/cpu_nms.pyx":45
+ * for _i in range(ndets):
+ * i = order[_i]
+ * if suppressed[i] == 1: # <<<<<<<<<<<<<<
+ * continue
+ * keep.append(i)
+ */
+ }
+
+ /* "nms/cpu_nms.pyx":47
+ * if suppressed[i] == 1:
+ * continue
+ * keep.append(i) # <<<<<<<<<<<<<<
+ * ix1 = x1[i]
+ * iy1 = y1[i]
+ */
+ __pyx_t_14 = __Pyx_PyInt_From_int(__pyx_v_i); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 47, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_14);
+ __pyx_t_22 = __Pyx_PyList_Append(__pyx_v_keep, __pyx_t_14); if (unlikely(__pyx_t_22 == ((int)-1))) __PYX_ERR(0, 47, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0;
+
+ /* "nms/cpu_nms.pyx":48
+ * continue
+ * keep.append(i)
+ * ix1 = x1[i] # <<<<<<<<<<<<<<
+ * iy1 = y1[i]
+ * ix2 = x2[i]
+ */
+ __pyx_t_19 = __pyx_v_i;
+ __pyx_t_20 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_x1.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_20 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_x1.diminfo[0].shape)) __pyx_t_20 = 0;
+ if (unlikely(__pyx_t_20 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_20);
+ __PYX_ERR(0, 48, __pyx_L1_error)
+ }
+ __pyx_v_ix1 = (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_x1.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_x1.diminfo[0].strides));
+
+ /* "nms/cpu_nms.pyx":49
+ * keep.append(i)
+ * ix1 = x1[i]
+ * iy1 = y1[i] # <<<<<<<<<<<<<<
+ * ix2 = x2[i]
+ * iy2 = y2[i]
+ */
+ __pyx_t_19 = __pyx_v_i;
+ __pyx_t_20 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_y1.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_20 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_y1.diminfo[0].shape)) __pyx_t_20 = 0;
+ if (unlikely(__pyx_t_20 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_20);
+ __PYX_ERR(0, 49, __pyx_L1_error)
+ }
+ __pyx_v_iy1 = (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_y1.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_y1.diminfo[0].strides));
+
+ /* "nms/cpu_nms.pyx":50
+ * ix1 = x1[i]
+ * iy1 = y1[i]
+ * ix2 = x2[i] # <<<<<<<<<<<<<<
+ * iy2 = y2[i]
+ * iarea = areas[i]
+ */
+ __pyx_t_19 = __pyx_v_i;
+ __pyx_t_20 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_x2.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_20 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_x2.diminfo[0].shape)) __pyx_t_20 = 0;
+ if (unlikely(__pyx_t_20 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_20);
+ __PYX_ERR(0, 50, __pyx_L1_error)
+ }
+ __pyx_v_ix2 = (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_x2.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_x2.diminfo[0].strides));
+
+ /* "nms/cpu_nms.pyx":51
+ * iy1 = y1[i]
+ * ix2 = x2[i]
+ * iy2 = y2[i] # <<<<<<<<<<<<<<
+ * iarea = areas[i]
+ * for _j in range(_i + 1, ndets):
+ */
+ __pyx_t_19 = __pyx_v_i;
+ __pyx_t_20 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_y2.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_20 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_y2.diminfo[0].shape)) __pyx_t_20 = 0;
+ if (unlikely(__pyx_t_20 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_20);
+ __PYX_ERR(0, 51, __pyx_L1_error)
+ }
+ __pyx_v_iy2 = (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_y2.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_y2.diminfo[0].strides));
+
+ /* "nms/cpu_nms.pyx":52
+ * ix2 = x2[i]
+ * iy2 = y2[i]
+ * iarea = areas[i] # <<<<<<<<<<<<<<
+ * for _j in range(_i + 1, ndets):
+ * j = order[_j]
+ */
+ __pyx_t_19 = __pyx_v_i;
+ __pyx_t_20 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_areas.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_20 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_areas.diminfo[0].shape)) __pyx_t_20 = 0;
+ if (unlikely(__pyx_t_20 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_20);
+ __PYX_ERR(0, 52, __pyx_L1_error)
+ }
+ __pyx_v_iarea = (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_areas.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_areas.diminfo[0].strides));
+
+ /* "nms/cpu_nms.pyx":53
+ * iy2 = y2[i]
+ * iarea = areas[i]
+ * for _j in range(_i + 1, ndets): # <<<<<<<<<<<<<<
+ * j = order[_j]
+ * if suppressed[j] == 1:
+ */
+ __pyx_t_20 = __pyx_v_ndets;
+ __pyx_t_23 = __pyx_t_20;
+ for (__pyx_t_24 = (__pyx_v__i + 1); __pyx_t_24 < __pyx_t_23; __pyx_t_24+=1) {
+ __pyx_v__j = __pyx_t_24;
+
+ /* "nms/cpu_nms.pyx":54
+ * iarea = areas[i]
+ * for _j in range(_i + 1, ndets):
+ * j = order[_j] # <<<<<<<<<<<<<<
+ * if suppressed[j] == 1:
+ * continue
+ */
+ __pyx_t_19 = __pyx_v__j;
+ __pyx_t_25 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_order.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_25 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_order.diminfo[0].shape)) __pyx_t_25 = 0;
+ if (unlikely(__pyx_t_25 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_25);
+ __PYX_ERR(0, 54, __pyx_L1_error)
+ }
+ __pyx_v_j = (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_int_t *, __pyx_pybuffernd_order.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_order.diminfo[0].strides));
+
+ /* "nms/cpu_nms.pyx":55
+ * for _j in range(_i + 1, ndets):
+ * j = order[_j]
+ * if suppressed[j] == 1: # <<<<<<<<<<<<<<
+ * continue
+ * xx1 = max(ix1, x1[j])
+ */
+ __pyx_t_19 = __pyx_v_j;
+ __pyx_t_25 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_suppressed.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_25 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_suppressed.diminfo[0].shape)) __pyx_t_25 = 0;
+ if (unlikely(__pyx_t_25 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_25);
+ __PYX_ERR(0, 55, __pyx_L1_error)
+ }
+ __pyx_t_21 = ((*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_int_t *, __pyx_pybuffernd_suppressed.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_suppressed.diminfo[0].strides)) == 1);
+ if (__pyx_t_21) {
+
+ /* "nms/cpu_nms.pyx":56
+ * j = order[_j]
+ * if suppressed[j] == 1:
+ * continue # <<<<<<<<<<<<<<
+ * xx1 = max(ix1, x1[j])
+ * yy1 = max(iy1, y1[j])
+ */
+ goto __pyx_L6_continue;
+
+ /* "nms/cpu_nms.pyx":55
+ * for _j in range(_i + 1, ndets):
+ * j = order[_j]
+ * if suppressed[j] == 1: # <<<<<<<<<<<<<<
+ * continue
+ * xx1 = max(ix1, x1[j])
+ */
+ }
+
+ /* "nms/cpu_nms.pyx":57
+ * if suppressed[j] == 1:
+ * continue
+ * xx1 = max(ix1, x1[j]) # <<<<<<<<<<<<<<
+ * yy1 = max(iy1, y1[j])
+ * xx2 = min(ix2, x2[j])
+ */
+ __pyx_t_19 = __pyx_v_j;
+ __pyx_t_25 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_x1.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_25 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_x1.diminfo[0].shape)) __pyx_t_25 = 0;
+ if (unlikely(__pyx_t_25 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_25);
+ __PYX_ERR(0, 57, __pyx_L1_error)
+ }
+ __pyx_t_26 = __pyx_f_3nms_7cpu_nms_max(__pyx_v_ix1, (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_x1.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_x1.diminfo[0].strides))); if (unlikely(__pyx_t_26 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 57, __pyx_L1_error)
+ __pyx_v_xx1 = __pyx_t_26;
+
+ /* "nms/cpu_nms.pyx":58
+ * continue
+ * xx1 = max(ix1, x1[j])
+ * yy1 = max(iy1, y1[j]) # <<<<<<<<<<<<<<
+ * xx2 = min(ix2, x2[j])
+ * yy2 = min(iy2, y2[j])
+ */
+ __pyx_t_19 = __pyx_v_j;
+ __pyx_t_25 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_y1.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_25 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_y1.diminfo[0].shape)) __pyx_t_25 = 0;
+ if (unlikely(__pyx_t_25 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_25);
+ __PYX_ERR(0, 58, __pyx_L1_error)
+ }
+ __pyx_t_26 = __pyx_f_3nms_7cpu_nms_max(__pyx_v_iy1, (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_y1.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_y1.diminfo[0].strides))); if (unlikely(__pyx_t_26 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 58, __pyx_L1_error)
+ __pyx_v_yy1 = __pyx_t_26;
+
+ /* "nms/cpu_nms.pyx":59
+ * xx1 = max(ix1, x1[j])
+ * yy1 = max(iy1, y1[j])
+ * xx2 = min(ix2, x2[j]) # <<<<<<<<<<<<<<
+ * yy2 = min(iy2, y2[j])
+ * w = max(0.0, xx2 - xx1 + 1)
+ */
+ __pyx_t_19 = __pyx_v_j;
+ __pyx_t_25 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_x2.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_25 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_x2.diminfo[0].shape)) __pyx_t_25 = 0;
+ if (unlikely(__pyx_t_25 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_25);
+ __PYX_ERR(0, 59, __pyx_L1_error)
+ }
+ __pyx_t_26 = __pyx_f_3nms_7cpu_nms_min(__pyx_v_ix2, (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_x2.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_x2.diminfo[0].strides))); if (unlikely(__pyx_t_26 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 59, __pyx_L1_error)
+ __pyx_v_xx2 = __pyx_t_26;
+
+ /* "nms/cpu_nms.pyx":60
+ * yy1 = max(iy1, y1[j])
+ * xx2 = min(ix2, x2[j])
+ * yy2 = min(iy2, y2[j]) # <<<<<<<<<<<<<<
+ * w = max(0.0, xx2 - xx1 + 1)
+ * h = max(0.0, yy2 - yy1 + 1)
+ */
+ __pyx_t_19 = __pyx_v_j;
+ __pyx_t_25 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_y2.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_25 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_y2.diminfo[0].shape)) __pyx_t_25 = 0;
+ if (unlikely(__pyx_t_25 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_25);
+ __PYX_ERR(0, 60, __pyx_L1_error)
+ }
+ __pyx_t_26 = __pyx_f_3nms_7cpu_nms_min(__pyx_v_iy2, (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_y2.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_y2.diminfo[0].strides))); if (unlikely(__pyx_t_26 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 60, __pyx_L1_error)
+ __pyx_v_yy2 = __pyx_t_26;
+
+ /* "nms/cpu_nms.pyx":61
+ * xx2 = min(ix2, x2[j])
+ * yy2 = min(iy2, y2[j])
+ * w = max(0.0, xx2 - xx1 + 1) # <<<<<<<<<<<<<<
+ * h = max(0.0, yy2 - yy1 + 1)
+ * inter = w * h
+ */
+ __pyx_t_26 = __pyx_f_3nms_7cpu_nms_max(0.0, ((__pyx_v_xx2 - __pyx_v_xx1) + 1.0)); if (unlikely(__pyx_t_26 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 61, __pyx_L1_error)
+ __pyx_v_w = __pyx_t_26;
+
+ /* "nms/cpu_nms.pyx":62
+ * yy2 = min(iy2, y2[j])
+ * w = max(0.0, xx2 - xx1 + 1)
+ * h = max(0.0, yy2 - yy1 + 1) # <<<<<<<<<<<<<<
+ * inter = w * h
+ * ovr = inter / (iarea + areas[j] - inter)
+ */
+ __pyx_t_26 = __pyx_f_3nms_7cpu_nms_max(0.0, ((__pyx_v_yy2 - __pyx_v_yy1) + 1.0)); if (unlikely(__pyx_t_26 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 62, __pyx_L1_error)
+ __pyx_v_h = __pyx_t_26;
+
+ /* "nms/cpu_nms.pyx":63
+ * w = max(0.0, xx2 - xx1 + 1)
+ * h = max(0.0, yy2 - yy1 + 1)
+ * inter = w * h # <<<<<<<<<<<<<<
+ * ovr = inter / (iarea + areas[j] - inter)
+ * if ovr >= thresh:
+ */
+ __pyx_v_inter = (__pyx_v_w * __pyx_v_h);
+
+ /* "nms/cpu_nms.pyx":64
+ * h = max(0.0, yy2 - yy1 + 1)
+ * inter = w * h
+ * ovr = inter / (iarea + areas[j] - inter) # <<<<<<<<<<<<<<
+ * if ovr >= thresh:
+ * suppressed[j] = 1
+ */
+ __pyx_t_19 = __pyx_v_j;
+ __pyx_t_25 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_areas.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_25 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_areas.diminfo[0].shape)) __pyx_t_25 = 0;
+ if (unlikely(__pyx_t_25 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_25);
+ __PYX_ERR(0, 64, __pyx_L1_error)
+ }
+ __pyx_t_26 = ((__pyx_v_iarea + (*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_float32_t *, __pyx_pybuffernd_areas.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_areas.diminfo[0].strides))) - __pyx_v_inter);
+ if (unlikely(__pyx_t_26 == 0)) {
+ PyErr_SetString(PyExc_ZeroDivisionError, "float division");
+ __PYX_ERR(0, 64, __pyx_L1_error)
+ }
+ __pyx_v_ovr = (__pyx_v_inter / __pyx_t_26);
+
+ /* "nms/cpu_nms.pyx":65
+ * inter = w * h
+ * ovr = inter / (iarea + areas[j] - inter)
+ * if ovr >= thresh: # <<<<<<<<<<<<<<
+ * suppressed[j] = 1
+ *
+ */
+ __pyx_t_14 = PyFloat_FromDouble(__pyx_v_ovr); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 65, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_14);
+ __pyx_t_8 = PyObject_RichCompare(__pyx_t_14, __pyx_v_thresh, Py_GE); __Pyx_XGOTREF(__pyx_t_8); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 65, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0;
+ __pyx_t_21 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_21 < 0))) __PYX_ERR(0, 65, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0;
+ if (__pyx_t_21) {
+
+ /* "nms/cpu_nms.pyx":66
+ * ovr = inter / (iarea + areas[j] - inter)
+ * if ovr >= thresh:
+ * suppressed[j] = 1 # <<<<<<<<<<<<<<
+ *
+ * return keep
+ */
+ __pyx_t_19 = __pyx_v_j;
+ __pyx_t_25 = -1;
+ if (__pyx_t_19 < 0) {
+ __pyx_t_19 += __pyx_pybuffernd_suppressed.diminfo[0].shape;
+ if (unlikely(__pyx_t_19 < 0)) __pyx_t_25 = 0;
+ } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_suppressed.diminfo[0].shape)) __pyx_t_25 = 0;
+ if (unlikely(__pyx_t_25 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_25);
+ __PYX_ERR(0, 66, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided1d(__pyx_t_5numpy_int_t *, __pyx_pybuffernd_suppressed.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_suppressed.diminfo[0].strides) = 1;
+
+ /* "nms/cpu_nms.pyx":65
+ * inter = w * h
+ * ovr = inter / (iarea + areas[j] - inter)
+ * if ovr >= thresh: # <<<<<<<<<<<<<<
+ * suppressed[j] = 1
+ *
+ */
+ }
+ __pyx_L6_continue:;
+ }
+ __pyx_L3_continue:;
+ }
+
+ /* "nms/cpu_nms.pyx":68
+ * suppressed[j] = 1
+ *
+ * return keep # <<<<<<<<<<<<<<
+ *
+ * def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0):
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __Pyx_INCREF(__pyx_v_keep);
+ __pyx_r = __pyx_v_keep;
+ goto __pyx_L0;
+
+ /* "nms/cpu_nms.pyx":17
+ * return a if a <= b else b
+ *
+ * def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_XDECREF(__pyx_t_7);
+ __Pyx_XDECREF(__pyx_t_8);
+ __Pyx_XDECREF(__pyx_t_13);
+ __Pyx_XDECREF(__pyx_t_14);
+ { PyObject *__pyx_type, *__pyx_value, *__pyx_tb;
+ __Pyx_PyThreadState_declare
+ __Pyx_PyThreadState_assign
+ __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_areas.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_dets.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_order.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_scores.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_suppressed.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_x1.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_x2.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_y1.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_y2.rcbuffer->pybuffer);
+ __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);}
+ __Pyx_AddTraceback("nms.cpu_nms.cpu_nms", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = NULL;
+ goto __pyx_L2;
+ __pyx_L0:;
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_areas.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_dets.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_order.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_scores.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_suppressed.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_x1.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_x2.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_y1.rcbuffer->pybuffer);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_y2.rcbuffer->pybuffer);
+ __pyx_L2:;
+ __Pyx_XDECREF((PyObject *)__pyx_v_x1);
+ __Pyx_XDECREF((PyObject *)__pyx_v_y1);
+ __Pyx_XDECREF((PyObject *)__pyx_v_x2);
+ __Pyx_XDECREF((PyObject *)__pyx_v_y2);
+ __Pyx_XDECREF((PyObject *)__pyx_v_scores);
+ __Pyx_XDECREF((PyObject *)__pyx_v_areas);
+ __Pyx_XDECREF((PyObject *)__pyx_v_order);
+ __Pyx_XDECREF((PyObject *)__pyx_v_suppressed);
+ __Pyx_XDECREF(__pyx_v_keep);
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+/* "nms/cpu_nms.pyx":70
+ * return keep
+ *
+ * def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): # <<<<<<<<<<<<<<
+ * cdef unsigned int N = boxes.shape[0]
+ * cdef float iw, ih, box_area
+ */
+
+/* Python wrapper */
+static PyObject *__pyx_pw_3nms_7cpu_nms_3cpu_soft_nms(PyObject *__pyx_self,
+#if CYTHON_METH_FASTCALL
+PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
+#else
+PyObject *__pyx_args, PyObject *__pyx_kwds
+#endif
+); /*proto*/
+static PyMethodDef __pyx_mdef_3nms_7cpu_nms_3cpu_soft_nms = {"cpu_soft_nms", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_3nms_7cpu_nms_3cpu_soft_nms, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0};
+static PyObject *__pyx_pw_3nms_7cpu_nms_3cpu_soft_nms(PyObject *__pyx_self,
+#if CYTHON_METH_FASTCALL
+PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
+#else
+PyObject *__pyx_args, PyObject *__pyx_kwds
+#endif
+) {
+ PyArrayObject *__pyx_v_boxes = 0;
+ float __pyx_v_sigma;
+ float __pyx_v_Nt;
+ float __pyx_v_threshold;
+ unsigned int __pyx_v_method;
+ #if !CYTHON_METH_FASTCALL
+ CYTHON_UNUSED Py_ssize_t __pyx_nargs;
+ #endif
+ CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
+ PyObject* values[5] = {0,0,0,0,0};
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ PyObject *__pyx_r = 0;
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("cpu_soft_nms (wrapper)", 0);
+ #if !CYTHON_METH_FASTCALL
+ #if CYTHON_ASSUME_SAFE_MACROS
+ __pyx_nargs = PyTuple_GET_SIZE(__pyx_args);
+ #else
+ __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL;
+ #endif
+ #endif
+ __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs);
+ {
+ PyObject **__pyx_pyargnames[] = {&__pyx_n_s_boxes,&__pyx_n_s_sigma,&__pyx_n_s_Nt,&__pyx_n_s_threshold,&__pyx_n_s_method,0};
+ if (__pyx_kwds) {
+ Py_ssize_t kw_args;
+ switch (__pyx_nargs) {
+ case 5: values[4] = __Pyx_Arg_FASTCALL(__pyx_args, 4);
+ CYTHON_FALLTHROUGH;
+ case 4: values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3);
+ CYTHON_FALLTHROUGH;
+ case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2);
+ CYTHON_FALLTHROUGH;
+ case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1);
+ CYTHON_FALLTHROUGH;
+ case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0);
+ CYTHON_FALLTHROUGH;
+ case 0: break;
+ default: goto __pyx_L5_argtuple_error;
+ }
+ kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds);
+ switch (__pyx_nargs) {
+ case 0:
+ if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_boxes)) != 0)) {
+ (void)__Pyx_Arg_NewRef_FASTCALL(values[0]);
+ kw_args--;
+ }
+ else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L3_error)
+ else goto __pyx_L5_argtuple_error;
+ CYTHON_FALLTHROUGH;
+ case 1:
+ if (kw_args > 0) {
+ PyObject* value = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_sigma);
+ if (value) { values[1] = __Pyx_Arg_NewRef_FASTCALL(value); kw_args--; }
+ else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L3_error)
+ }
+ CYTHON_FALLTHROUGH;
+ case 2:
+ if (kw_args > 0) {
+ PyObject* value = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_Nt);
+ if (value) { values[2] = __Pyx_Arg_NewRef_FASTCALL(value); kw_args--; }
+ else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L3_error)
+ }
+ CYTHON_FALLTHROUGH;
+ case 3:
+ if (kw_args > 0) {
+ PyObject* value = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_threshold);
+ if (value) { values[3] = __Pyx_Arg_NewRef_FASTCALL(value); kw_args--; }
+ else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L3_error)
+ }
+ CYTHON_FALLTHROUGH;
+ case 4:
+ if (kw_args > 0) {
+ PyObject* value = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_method);
+ if (value) { values[4] = __Pyx_Arg_NewRef_FASTCALL(value); kw_args--; }
+ else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L3_error)
+ }
+ }
+ if (unlikely(kw_args > 0)) {
+ const Py_ssize_t kwd_pos_args = __pyx_nargs;
+ if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "cpu_soft_nms") < 0)) __PYX_ERR(0, 70, __pyx_L3_error)
+ }
+ } else {
+ switch (__pyx_nargs) {
+ case 5: values[4] = __Pyx_Arg_FASTCALL(__pyx_args, 4);
+ CYTHON_FALLTHROUGH;
+ case 4: values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3);
+ CYTHON_FALLTHROUGH;
+ case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2);
+ CYTHON_FALLTHROUGH;
+ case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1);
+ CYTHON_FALLTHROUGH;
+ case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0);
+ break;
+ default: goto __pyx_L5_argtuple_error;
+ }
+ }
+ __pyx_v_boxes = ((PyArrayObject *)values[0]);
+ if (values[1]) {
+ __pyx_v_sigma = __pyx_PyFloat_AsFloat(values[1]); if (unlikely((__pyx_v_sigma == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L3_error)
+ } else {
+ __pyx_v_sigma = ((float)((double)0.5));
+ }
+ if (values[2]) {
+ __pyx_v_Nt = __pyx_PyFloat_AsFloat(values[2]); if (unlikely((__pyx_v_Nt == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L3_error)
+ } else {
+ __pyx_v_Nt = ((float)((double)0.3));
+ }
+ if (values[3]) {
+ __pyx_v_threshold = __pyx_PyFloat_AsFloat(values[3]); if (unlikely((__pyx_v_threshold == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L3_error)
+ } else {
+ __pyx_v_threshold = ((float)((double)0.001));
+ }
+ if (values[4]) {
+ __pyx_v_method = __Pyx_PyInt_As_unsigned_int(values[4]); if (unlikely((__pyx_v_method == (unsigned int)-1) && PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L3_error)
+ } else {
+ __pyx_v_method = ((unsigned int)((unsigned int)0));
+ }
+ }
+ goto __pyx_L6_skip;
+ __pyx_L5_argtuple_error:;
+ __Pyx_RaiseArgtupleInvalid("cpu_soft_nms", 0, 1, 5, __pyx_nargs); __PYX_ERR(0, 70, __pyx_L3_error)
+ __pyx_L6_skip:;
+ goto __pyx_L4_argument_unpacking_done;
+ __pyx_L3_error:;
+ {
+ Py_ssize_t __pyx_temp;
+ for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
+ __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]);
+ }
+ }
+ __Pyx_AddTraceback("nms.cpu_nms.cpu_soft_nms", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __Pyx_RefNannyFinishContext();
+ return NULL;
+ __pyx_L4_argument_unpacking_done:;
+ if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_boxes), __pyx_ptype_5numpy_ndarray, 1, "boxes", 0))) __PYX_ERR(0, 70, __pyx_L1_error)
+ __pyx_r = __pyx_pf_3nms_7cpu_nms_2cpu_soft_nms(__pyx_self, __pyx_v_boxes, __pyx_v_sigma, __pyx_v_Nt, __pyx_v_threshold, __pyx_v_method);
+
+ /* function exit code */
+ goto __pyx_L0;
+ __pyx_L1_error:;
+ __pyx_r = NULL;
+ __pyx_L0:;
+ {
+ Py_ssize_t __pyx_temp;
+ for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
+ __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]);
+ }
+ }
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+static PyObject *__pyx_pf_3nms_7cpu_nms_2cpu_soft_nms(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_boxes, float __pyx_v_sigma, float __pyx_v_Nt, float __pyx_v_threshold, unsigned int __pyx_v_method) {
+ unsigned int __pyx_v_N;
+ float __pyx_v_iw;
+ float __pyx_v_ih;
+ float __pyx_v_ua;
+ int __pyx_v_pos;
+ float __pyx_v_maxscore;
+ int __pyx_v_maxpos;
+ float __pyx_v_x1;
+ float __pyx_v_x2;
+ float __pyx_v_y1;
+ float __pyx_v_y2;
+ float __pyx_v_tx1;
+ float __pyx_v_tx2;
+ float __pyx_v_ty1;
+ float __pyx_v_ty2;
+ float __pyx_v_ts;
+ float __pyx_v_area;
+ float __pyx_v_weight;
+ float __pyx_v_ov;
+ PyObject *__pyx_v_i = NULL;
+ CYTHON_UNUSED PyObject *__pyx_v_s = NULL;
+ PyObject *__pyx_v_keep = NULL;
+ unsigned int __pyx_7genexpr__pyx_v_i;
+ __Pyx_LocalBuf_ND __pyx_pybuffernd_boxes;
+ __Pyx_Buffer __pyx_pybuffer_boxes;
+ PyObject *__pyx_r = NULL;
+ __Pyx_RefNannyDeclarations
+ npy_intp *__pyx_t_1;
+ PyObject *__pyx_t_2 = NULL;
+ PyObject *__pyx_t_3 = NULL;
+ Py_ssize_t __pyx_t_4;
+ PyObject *(*__pyx_t_5)(PyObject *);
+ PyObject *__pyx_t_6 = NULL;
+ float __pyx_t_7;
+ int __pyx_t_8;
+ int __pyx_t_9;
+ Py_ssize_t __pyx_t_10;
+ Py_ssize_t __pyx_t_11;
+ __pyx_t_5numpy_float32_t __pyx_t_12;
+ __pyx_t_5numpy_float32_t __pyx_t_13;
+ PyObject *__pyx_t_14 = NULL;
+ PyObject *__pyx_t_15 = NULL;
+ unsigned int __pyx_t_16;
+ Py_ssize_t __pyx_t_17;
+ Py_ssize_t __pyx_t_18;
+ unsigned int __pyx_t_19;
+ unsigned int __pyx_t_20;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("cpu_soft_nms", 1);
+ __pyx_pybuffer_boxes.pybuffer.buf = NULL;
+ __pyx_pybuffer_boxes.refcount = 0;
+ __pyx_pybuffernd_boxes.data = NULL;
+ __pyx_pybuffernd_boxes.rcbuffer = &__pyx_pybuffer_boxes;
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_boxes.rcbuffer->pybuffer, (PyObject*)__pyx_v_boxes, &__Pyx_TypeInfo_float, PyBUF_FORMAT| PyBUF_STRIDES| PyBUF_WRITABLE, 2, 0, __pyx_stack) == -1)) __PYX_ERR(0, 70, __pyx_L1_error)
+ }
+ __pyx_pybuffernd_boxes.diminfo[0].strides = __pyx_pybuffernd_boxes.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_boxes.diminfo[0].shape = __pyx_pybuffernd_boxes.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_boxes.diminfo[1].strides = __pyx_pybuffernd_boxes.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_boxes.diminfo[1].shape = __pyx_pybuffernd_boxes.rcbuffer->pybuffer.shape[1];
+
+ /* "nms/cpu_nms.pyx":71
+ *
+ * def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0):
+ * cdef unsigned int N = boxes.shape[0] # <<<<<<<<<<<<<<
+ * cdef float iw, ih, box_area
+ * cdef float ua
+ */
+ __pyx_t_1 = __pyx_f_5numpy_7ndarray_5shape_shape(((PyArrayObject *)__pyx_v_boxes)); if (unlikely(__pyx_t_1 == ((npy_intp *)NULL) && PyErr_Occurred())) __PYX_ERR(0, 71, __pyx_L1_error)
+ __pyx_v_N = (__pyx_t_1[0]);
+
+ /* "nms/cpu_nms.pyx":74
+ * cdef float iw, ih, box_area
+ * cdef float ua
+ * cdef int pos = 0 # <<<<<<<<<<<<<<
+ * cdef float maxscore = 0
+ * cdef int maxpos = 0
+ */
+ __pyx_v_pos = 0;
+
+ /* "nms/cpu_nms.pyx":75
+ * cdef float ua
+ * cdef int pos = 0
+ * cdef float maxscore = 0 # <<<<<<<<<<<<<<
+ * cdef int maxpos = 0
+ * cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov
+ */
+ __pyx_v_maxscore = 0.0;
+
+ /* "nms/cpu_nms.pyx":76
+ * cdef int pos = 0
+ * cdef float maxscore = 0
+ * cdef int maxpos = 0 # <<<<<<<<<<<<<<
+ * cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov
+ *
+ */
+ __pyx_v_maxpos = 0;
+
+ /* "nms/cpu_nms.pyx":79
+ * cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov
+ *
+ * for i in range(N): # <<<<<<<<<<<<<<
+ * maxscore = boxes[i, 4]
+ * maxpos = i
+ */
+ __pyx_t_2 = __Pyx_PyInt_From_unsigned_int(__pyx_v_N); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 79, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_2);
+ __pyx_t_3 = __Pyx_PyObject_CallOneArg(__pyx_builtin_range, __pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 79, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+ if (likely(PyList_CheckExact(__pyx_t_3)) || PyTuple_CheckExact(__pyx_t_3)) {
+ __pyx_t_2 = __pyx_t_3; __Pyx_INCREF(__pyx_t_2);
+ __pyx_t_4 = 0;
+ __pyx_t_5 = NULL;
+ } else {
+ __pyx_t_4 = -1; __pyx_t_2 = PyObject_GetIter(__pyx_t_3); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 79, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_2);
+ __pyx_t_5 = __Pyx_PyObject_GetIterNextFunc(__pyx_t_2); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 79, __pyx_L1_error)
+ }
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ for (;;) {
+ if (likely(!__pyx_t_5)) {
+ if (likely(PyList_CheckExact(__pyx_t_2))) {
+ {
+ Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_2);
+ #if !CYTHON_ASSUME_SAFE_MACROS
+ if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 79, __pyx_L1_error)
+ #endif
+ if (__pyx_t_4 >= __pyx_temp) break;
+ }
+ #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
+ __pyx_t_3 = PyList_GET_ITEM(__pyx_t_2, __pyx_t_4); __Pyx_INCREF(__pyx_t_3); __pyx_t_4++; if (unlikely((0 < 0))) __PYX_ERR(0, 79, __pyx_L1_error)
+ #else
+ __pyx_t_3 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_4); __pyx_t_4++; if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 79, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ #endif
+ } else {
+ {
+ Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_2);
+ #if !CYTHON_ASSUME_SAFE_MACROS
+ if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 79, __pyx_L1_error)
+ #endif
+ if (__pyx_t_4 >= __pyx_temp) break;
+ }
+ #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
+ __pyx_t_3 = PyTuple_GET_ITEM(__pyx_t_2, __pyx_t_4); __Pyx_INCREF(__pyx_t_3); __pyx_t_4++; if (unlikely((0 < 0))) __PYX_ERR(0, 79, __pyx_L1_error)
+ #else
+ __pyx_t_3 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_4); __pyx_t_4++; if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 79, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ #endif
+ }
+ } else {
+ __pyx_t_3 = __pyx_t_5(__pyx_t_2);
+ if (unlikely(!__pyx_t_3)) {
+ PyObject* exc_type = PyErr_Occurred();
+ if (exc_type) {
+ if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear();
+ else __PYX_ERR(0, 79, __pyx_L1_error)
+ }
+ break;
+ }
+ __Pyx_GOTREF(__pyx_t_3);
+ }
+ __Pyx_XDECREF_SET(__pyx_v_i, __pyx_t_3);
+ __pyx_t_3 = 0;
+
+ /* "nms/cpu_nms.pyx":80
+ *
+ * for i in range(N):
+ * maxscore = boxes[i, 4] # <<<<<<<<<<<<<<
+ * maxpos = i
+ *
+ */
+ __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 80, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_i)) __PYX_ERR(0, 80, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_4);
+ __Pyx_GIVEREF(__pyx_int_4);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_int_4)) __PYX_ERR(0, 80, __pyx_L1_error);
+ __pyx_t_6 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_3); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 80, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_6); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 80, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_v_maxscore = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":81
+ * for i in range(N):
+ * maxscore = boxes[i, 4]
+ * maxpos = i # <<<<<<<<<<<<<<
+ *
+ * tx1 = boxes[i,0]
+ */
+ __pyx_t_8 = __Pyx_PyInt_As_int(__pyx_v_i); if (unlikely((__pyx_t_8 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 81, __pyx_L1_error)
+ __pyx_v_maxpos = __pyx_t_8;
+
+ /* "nms/cpu_nms.pyx":83
+ * maxpos = i
+ *
+ * tx1 = boxes[i,0] # <<<<<<<<<<<<<<
+ * ty1 = boxes[i,1]
+ * tx2 = boxes[i,2]
+ */
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 83, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 83, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_0);
+ __Pyx_GIVEREF(__pyx_int_0);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_0)) __PYX_ERR(0, 83, __pyx_L1_error);
+ __pyx_t_3 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 83, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_3); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 83, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_v_tx1 = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":84
+ *
+ * tx1 = boxes[i,0]
+ * ty1 = boxes[i,1] # <<<<<<<<<<<<<<
+ * tx2 = boxes[i,2]
+ * ty2 = boxes[i,3]
+ */
+ __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 84, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_i)) __PYX_ERR(0, 84, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_1);
+ __Pyx_GIVEREF(__pyx_int_1);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_int_1)) __PYX_ERR(0, 84, __pyx_L1_error);
+ __pyx_t_6 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_3); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 84, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_6); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 84, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_v_ty1 = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":85
+ * tx1 = boxes[i,0]
+ * ty1 = boxes[i,1]
+ * tx2 = boxes[i,2] # <<<<<<<<<<<<<<
+ * ty2 = boxes[i,3]
+ * ts = boxes[i,4]
+ */
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 85, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 85, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_2);
+ __Pyx_GIVEREF(__pyx_int_2);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_2)) __PYX_ERR(0, 85, __pyx_L1_error);
+ __pyx_t_3 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 85, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_3); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 85, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_v_tx2 = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":86
+ * ty1 = boxes[i,1]
+ * tx2 = boxes[i,2]
+ * ty2 = boxes[i,3] # <<<<<<<<<<<<<<
+ * ts = boxes[i,4]
+ *
+ */
+ __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 86, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_i)) __PYX_ERR(0, 86, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_3);
+ __Pyx_GIVEREF(__pyx_int_3);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_int_3)) __PYX_ERR(0, 86, __pyx_L1_error);
+ __pyx_t_6 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_3); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 86, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_6); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 86, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_v_ty2 = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":87
+ * tx2 = boxes[i,2]
+ * ty2 = boxes[i,3]
+ * ts = boxes[i,4] # <<<<<<<<<<<<<<
+ *
+ * pos = i + 1
+ */
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 87, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 87, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_4);
+ __Pyx_GIVEREF(__pyx_int_4);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_4)) __PYX_ERR(0, 87, __pyx_L1_error);
+ __pyx_t_3 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 87, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_3); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 87, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_v_ts = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":89
+ * ts = boxes[i,4]
+ *
+ * pos = i + 1 # <<<<<<<<<<<<<<
+ * # get max box
+ * while pos < N:
+ */
+ __pyx_t_3 = __Pyx_PyInt_AddObjC(__pyx_v_i, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 89, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_8 = __Pyx_PyInt_As_int(__pyx_t_3); if (unlikely((__pyx_t_8 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 89, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_v_pos = __pyx_t_8;
+
+ /* "nms/cpu_nms.pyx":91
+ * pos = i + 1
+ * # get max box
+ * while pos < N: # <<<<<<<<<<<<<<
+ * if maxscore < boxes[pos, 4]:
+ * maxscore = boxes[pos, 4]
+ */
+ while (1) {
+ __pyx_t_9 = (__pyx_v_pos < __pyx_v_N);
+ if (!__pyx_t_9) break;
+
+ /* "nms/cpu_nms.pyx":92
+ * # get max box
+ * while pos < N:
+ * if maxscore < boxes[pos, 4]: # <<<<<<<<<<<<<<
+ * maxscore = boxes[pos, 4]
+ * maxpos = pos
+ */
+ __pyx_t_10 = __pyx_v_pos;
+ __pyx_t_11 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 92, __pyx_L1_error)
+ }
+ __pyx_t_9 = (__pyx_v_maxscore < (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides)));
+ if (__pyx_t_9) {
+
+ /* "nms/cpu_nms.pyx":93
+ * while pos < N:
+ * if maxscore < boxes[pos, 4]:
+ * maxscore = boxes[pos, 4] # <<<<<<<<<<<<<<
+ * maxpos = pos
+ * pos = pos + 1
+ */
+ __pyx_t_11 = __pyx_v_pos;
+ __pyx_t_10 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 93, __pyx_L1_error)
+ }
+ __pyx_v_maxscore = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":94
+ * if maxscore < boxes[pos, 4]:
+ * maxscore = boxes[pos, 4]
+ * maxpos = pos # <<<<<<<<<<<<<<
+ * pos = pos + 1
+ *
+ */
+ __pyx_v_maxpos = __pyx_v_pos;
+
+ /* "nms/cpu_nms.pyx":92
+ * # get max box
+ * while pos < N:
+ * if maxscore < boxes[pos, 4]: # <<<<<<<<<<<<<<
+ * maxscore = boxes[pos, 4]
+ * maxpos = pos
+ */
+ }
+
+ /* "nms/cpu_nms.pyx":95
+ * maxscore = boxes[pos, 4]
+ * maxpos = pos
+ * pos = pos + 1 # <<<<<<<<<<<<<<
+ *
+ * # add max box as a detection
+ */
+ __pyx_v_pos = (__pyx_v_pos + 1);
+ }
+
+ /* "nms/cpu_nms.pyx":98
+ *
+ * # add max box as a detection
+ * boxes[i,0] = boxes[maxpos,0] # <<<<<<<<<<<<<<
+ * boxes[i,1] = boxes[maxpos,1]
+ * boxes[i,2] = boxes[maxpos,2]
+ */
+ __pyx_t_10 = __pyx_v_maxpos;
+ __pyx_t_11 = 0;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 98, __pyx_L1_error)
+ }
+ __pyx_t_3 = PyFloat_FromDouble((*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides))); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 98, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 98, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 98, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_0);
+ __Pyx_GIVEREF(__pyx_int_0);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_0)) __PYX_ERR(0, 98, __pyx_L1_error);
+ if (unlikely((PyObject_SetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6, __pyx_t_3) < 0))) __PYX_ERR(0, 98, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+
+ /* "nms/cpu_nms.pyx":99
+ * # add max box as a detection
+ * boxes[i,0] = boxes[maxpos,0]
+ * boxes[i,1] = boxes[maxpos,1] # <<<<<<<<<<<<<<
+ * boxes[i,2] = boxes[maxpos,2]
+ * boxes[i,3] = boxes[maxpos,3]
+ */
+ __pyx_t_11 = __pyx_v_maxpos;
+ __pyx_t_10 = 1;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 99, __pyx_L1_error)
+ }
+ __pyx_t_3 = PyFloat_FromDouble((*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides))); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 99, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 99, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 99, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_1);
+ __Pyx_GIVEREF(__pyx_int_1);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_1)) __PYX_ERR(0, 99, __pyx_L1_error);
+ if (unlikely((PyObject_SetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6, __pyx_t_3) < 0))) __PYX_ERR(0, 99, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+
+ /* "nms/cpu_nms.pyx":100
+ * boxes[i,0] = boxes[maxpos,0]
+ * boxes[i,1] = boxes[maxpos,1]
+ * boxes[i,2] = boxes[maxpos,2] # <<<<<<<<<<<<<<
+ * boxes[i,3] = boxes[maxpos,3]
+ * boxes[i,4] = boxes[maxpos,4]
+ */
+ __pyx_t_10 = __pyx_v_maxpos;
+ __pyx_t_11 = 2;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 100, __pyx_L1_error)
+ }
+ __pyx_t_3 = PyFloat_FromDouble((*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides))); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 100, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 100, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 100, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_2);
+ __Pyx_GIVEREF(__pyx_int_2);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_2)) __PYX_ERR(0, 100, __pyx_L1_error);
+ if (unlikely((PyObject_SetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6, __pyx_t_3) < 0))) __PYX_ERR(0, 100, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+
+ /* "nms/cpu_nms.pyx":101
+ * boxes[i,1] = boxes[maxpos,1]
+ * boxes[i,2] = boxes[maxpos,2]
+ * boxes[i,3] = boxes[maxpos,3] # <<<<<<<<<<<<<<
+ * boxes[i,4] = boxes[maxpos,4]
+ *
+ */
+ __pyx_t_11 = __pyx_v_maxpos;
+ __pyx_t_10 = 3;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 101, __pyx_L1_error)
+ }
+ __pyx_t_3 = PyFloat_FromDouble((*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides))); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 101, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 101, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 101, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_3);
+ __Pyx_GIVEREF(__pyx_int_3);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_3)) __PYX_ERR(0, 101, __pyx_L1_error);
+ if (unlikely((PyObject_SetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6, __pyx_t_3) < 0))) __PYX_ERR(0, 101, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+
+ /* "nms/cpu_nms.pyx":102
+ * boxes[i,2] = boxes[maxpos,2]
+ * boxes[i,3] = boxes[maxpos,3]
+ * boxes[i,4] = boxes[maxpos,4] # <<<<<<<<<<<<<<
+ *
+ * # swap ith box with position of max box
+ */
+ __pyx_t_10 = __pyx_v_maxpos;
+ __pyx_t_11 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 102, __pyx_L1_error)
+ }
+ __pyx_t_3 = PyFloat_FromDouble((*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides))); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 102, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 102, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 102, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_4);
+ __Pyx_GIVEREF(__pyx_int_4);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_4)) __PYX_ERR(0, 102, __pyx_L1_error);
+ if (unlikely((PyObject_SetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6, __pyx_t_3) < 0))) __PYX_ERR(0, 102, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+
+ /* "nms/cpu_nms.pyx":105
+ *
+ * # swap ith box with position of max box
+ * boxes[maxpos,0] = tx1 # <<<<<<<<<<<<<<
+ * boxes[maxpos,1] = ty1
+ * boxes[maxpos,2] = tx2
+ */
+ __pyx_t_11 = __pyx_v_maxpos;
+ __pyx_t_10 = 0;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 105, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides) = __pyx_v_tx1;
+
+ /* "nms/cpu_nms.pyx":106
+ * # swap ith box with position of max box
+ * boxes[maxpos,0] = tx1
+ * boxes[maxpos,1] = ty1 # <<<<<<<<<<<<<<
+ * boxes[maxpos,2] = tx2
+ * boxes[maxpos,3] = ty2
+ */
+ __pyx_t_10 = __pyx_v_maxpos;
+ __pyx_t_11 = 1;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 106, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides) = __pyx_v_ty1;
+
+ /* "nms/cpu_nms.pyx":107
+ * boxes[maxpos,0] = tx1
+ * boxes[maxpos,1] = ty1
+ * boxes[maxpos,2] = tx2 # <<<<<<<<<<<<<<
+ * boxes[maxpos,3] = ty2
+ * boxes[maxpos,4] = ts
+ */
+ __pyx_t_11 = __pyx_v_maxpos;
+ __pyx_t_10 = 2;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 107, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides) = __pyx_v_tx2;
+
+ /* "nms/cpu_nms.pyx":108
+ * boxes[maxpos,1] = ty1
+ * boxes[maxpos,2] = tx2
+ * boxes[maxpos,3] = ty2 # <<<<<<<<<<<<<<
+ * boxes[maxpos,4] = ts
+ *
+ */
+ __pyx_t_10 = __pyx_v_maxpos;
+ __pyx_t_11 = 3;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 108, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides) = __pyx_v_ty2;
+
+ /* "nms/cpu_nms.pyx":109
+ * boxes[maxpos,2] = tx2
+ * boxes[maxpos,3] = ty2
+ * boxes[maxpos,4] = ts # <<<<<<<<<<<<<<
+ *
+ * tx1 = boxes[i,0]
+ */
+ __pyx_t_11 = __pyx_v_maxpos;
+ __pyx_t_10 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 109, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides) = __pyx_v_ts;
+
+ /* "nms/cpu_nms.pyx":111
+ * boxes[maxpos,4] = ts
+ *
+ * tx1 = boxes[i,0] # <<<<<<<<<<<<<<
+ * ty1 = boxes[i,1]
+ * tx2 = boxes[i,2]
+ */
+ __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 111, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_i)) __PYX_ERR(0, 111, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_0);
+ __Pyx_GIVEREF(__pyx_int_0);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_int_0)) __PYX_ERR(0, 111, __pyx_L1_error);
+ __pyx_t_6 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_3); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 111, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_6); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 111, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_v_tx1 = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":112
+ *
+ * tx1 = boxes[i,0]
+ * ty1 = boxes[i,1] # <<<<<<<<<<<<<<
+ * tx2 = boxes[i,2]
+ * ty2 = boxes[i,3]
+ */
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 112, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 112, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_1);
+ __Pyx_GIVEREF(__pyx_int_1);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_1)) __PYX_ERR(0, 112, __pyx_L1_error);
+ __pyx_t_3 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 112, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_3); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 112, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_v_ty1 = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":113
+ * tx1 = boxes[i,0]
+ * ty1 = boxes[i,1]
+ * tx2 = boxes[i,2] # <<<<<<<<<<<<<<
+ * ty2 = boxes[i,3]
+ * ts = boxes[i,4]
+ */
+ __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 113, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_i)) __PYX_ERR(0, 113, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_2);
+ __Pyx_GIVEREF(__pyx_int_2);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_int_2)) __PYX_ERR(0, 113, __pyx_L1_error);
+ __pyx_t_6 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_3); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 113, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_6); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 113, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_v_tx2 = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":114
+ * ty1 = boxes[i,1]
+ * tx2 = boxes[i,2]
+ * ty2 = boxes[i,3] # <<<<<<<<<<<<<<
+ * ts = boxes[i,4]
+ *
+ */
+ __pyx_t_6 = PyTuple_New(2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 114, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_v_i)) __PYX_ERR(0, 114, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_3);
+ __Pyx_GIVEREF(__pyx_int_3);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_int_3)) __PYX_ERR(0, 114, __pyx_L1_error);
+ __pyx_t_3 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_6); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 114, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_3); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 114, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_v_ty2 = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":115
+ * tx2 = boxes[i,2]
+ * ty2 = boxes[i,3]
+ * ts = boxes[i,4] # <<<<<<<<<<<<<<
+ *
+ * pos = i + 1
+ */
+ __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 115, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_INCREF(__pyx_v_i);
+ __Pyx_GIVEREF(__pyx_v_i);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_i)) __PYX_ERR(0, 115, __pyx_L1_error);
+ __Pyx_INCREF(__pyx_int_4);
+ __Pyx_GIVEREF(__pyx_int_4);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_int_4)) __PYX_ERR(0, 115, __pyx_L1_error);
+ __pyx_t_6 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_boxes), __pyx_t_3); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 115, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_6); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 115, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_v_ts = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":117
+ * ts = boxes[i,4]
+ *
+ * pos = i + 1 # <<<<<<<<<<<<<<
+ * # NMS iterations, note that N changes if detection boxes fall below threshold
+ * while pos < N:
+ */
+ __pyx_t_6 = __Pyx_PyInt_AddObjC(__pyx_v_i, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 117, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __pyx_t_8 = __Pyx_PyInt_As_int(__pyx_t_6); if (unlikely((__pyx_t_8 == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 117, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_v_pos = __pyx_t_8;
+
+ /* "nms/cpu_nms.pyx":119
+ * pos = i + 1
+ * # NMS iterations, note that N changes if detection boxes fall below threshold
+ * while pos < N: # <<<<<<<<<<<<<<
+ * x1 = boxes[pos, 0]
+ * y1 = boxes[pos, 1]
+ */
+ while (1) {
+ __pyx_t_9 = (__pyx_v_pos < __pyx_v_N);
+ if (!__pyx_t_9) break;
+
+ /* "nms/cpu_nms.pyx":120
+ * # NMS iterations, note that N changes if detection boxes fall below threshold
+ * while pos < N:
+ * x1 = boxes[pos, 0] # <<<<<<<<<<<<<<
+ * y1 = boxes[pos, 1]
+ * x2 = boxes[pos, 2]
+ */
+ __pyx_t_10 = __pyx_v_pos;
+ __pyx_t_11 = 0;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 120, __pyx_L1_error)
+ }
+ __pyx_v_x1 = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":121
+ * while pos < N:
+ * x1 = boxes[pos, 0]
+ * y1 = boxes[pos, 1] # <<<<<<<<<<<<<<
+ * x2 = boxes[pos, 2]
+ * y2 = boxes[pos, 3]
+ */
+ __pyx_t_11 = __pyx_v_pos;
+ __pyx_t_10 = 1;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 121, __pyx_L1_error)
+ }
+ __pyx_v_y1 = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":122
+ * x1 = boxes[pos, 0]
+ * y1 = boxes[pos, 1]
+ * x2 = boxes[pos, 2] # <<<<<<<<<<<<<<
+ * y2 = boxes[pos, 3]
+ * s = boxes[pos, 4]
+ */
+ __pyx_t_10 = __pyx_v_pos;
+ __pyx_t_11 = 2;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 122, __pyx_L1_error)
+ }
+ __pyx_v_x2 = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":123
+ * y1 = boxes[pos, 1]
+ * x2 = boxes[pos, 2]
+ * y2 = boxes[pos, 3] # <<<<<<<<<<<<<<
+ * s = boxes[pos, 4]
+ *
+ */
+ __pyx_t_11 = __pyx_v_pos;
+ __pyx_t_10 = 3;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 123, __pyx_L1_error)
+ }
+ __pyx_v_y2 = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":124
+ * x2 = boxes[pos, 2]
+ * y2 = boxes[pos, 3]
+ * s = boxes[pos, 4] # <<<<<<<<<<<<<<
+ *
+ * area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ */
+ __pyx_t_10 = __pyx_v_pos;
+ __pyx_t_11 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 124, __pyx_L1_error)
+ }
+ __pyx_t_6 = PyFloat_FromDouble((*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides))); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 124, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_XDECREF_SET(__pyx_v_s, __pyx_t_6);
+ __pyx_t_6 = 0;
+
+ /* "nms/cpu_nms.pyx":126
+ * s = boxes[pos, 4]
+ *
+ * area = (x2 - x1 + 1) * (y2 - y1 + 1) # <<<<<<<<<<<<<<
+ * iw = (min(tx2, x2) - max(tx1, x1) + 1)
+ * if iw > 0:
+ */
+ __pyx_v_area = (((__pyx_v_x2 - __pyx_v_x1) + 1.0) * ((__pyx_v_y2 - __pyx_v_y1) + 1.0));
+
+ /* "nms/cpu_nms.pyx":127
+ *
+ * area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ * iw = (min(tx2, x2) - max(tx1, x1) + 1) # <<<<<<<<<<<<<<
+ * if iw > 0:
+ * ih = (min(ty2, y2) - max(ty1, y1) + 1)
+ */
+ __pyx_t_12 = __pyx_f_3nms_7cpu_nms_min(__pyx_v_tx2, __pyx_v_x2); if (unlikely(__pyx_t_12 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 127, __pyx_L1_error)
+ __pyx_t_13 = __pyx_f_3nms_7cpu_nms_max(__pyx_v_tx1, __pyx_v_x1); if (unlikely(__pyx_t_13 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 127, __pyx_L1_error)
+ __pyx_v_iw = ((__pyx_t_12 - __pyx_t_13) + 1.0);
+
+ /* "nms/cpu_nms.pyx":128
+ * area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ * iw = (min(tx2, x2) - max(tx1, x1) + 1)
+ * if iw > 0: # <<<<<<<<<<<<<<
+ * ih = (min(ty2, y2) - max(ty1, y1) + 1)
+ * if ih > 0:
+ */
+ __pyx_t_9 = (__pyx_v_iw > 0.0);
+ if (__pyx_t_9) {
+
+ /* "nms/cpu_nms.pyx":129
+ * iw = (min(tx2, x2) - max(tx1, x1) + 1)
+ * if iw > 0:
+ * ih = (min(ty2, y2) - max(ty1, y1) + 1) # <<<<<<<<<<<<<<
+ * if ih > 0:
+ * ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
+ */
+ __pyx_t_13 = __pyx_f_3nms_7cpu_nms_min(__pyx_v_ty2, __pyx_v_y2); if (unlikely(__pyx_t_13 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 129, __pyx_L1_error)
+ __pyx_t_12 = __pyx_f_3nms_7cpu_nms_max(__pyx_v_ty1, __pyx_v_y1); if (unlikely(__pyx_t_12 == ((__pyx_t_5numpy_float32_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 129, __pyx_L1_error)
+ __pyx_v_ih = ((__pyx_t_13 - __pyx_t_12) + 1.0);
+
+ /* "nms/cpu_nms.pyx":130
+ * if iw > 0:
+ * ih = (min(ty2, y2) - max(ty1, y1) + 1)
+ * if ih > 0: # <<<<<<<<<<<<<<
+ * ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
+ * ov = iw * ih / ua #iou between max box and detection box
+ */
+ __pyx_t_9 = (__pyx_v_ih > 0.0);
+ if (__pyx_t_9) {
+
+ /* "nms/cpu_nms.pyx":131
+ * ih = (min(ty2, y2) - max(ty1, y1) + 1)
+ * if ih > 0:
+ * ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) # <<<<<<<<<<<<<<
+ * ov = iw * ih / ua #iou between max box and detection box
+ *
+ */
+ __pyx_v_ua = ((double)(((((__pyx_v_tx2 - __pyx_v_tx1) + 1.0) * ((__pyx_v_ty2 - __pyx_v_ty1) + 1.0)) + __pyx_v_area) - (__pyx_v_iw * __pyx_v_ih)));
+
+ /* "nms/cpu_nms.pyx":132
+ * if ih > 0:
+ * ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
+ * ov = iw * ih / ua #iou between max box and detection box # <<<<<<<<<<<<<<
+ *
+ * if method == 1: # linear
+ */
+ __pyx_t_7 = (__pyx_v_iw * __pyx_v_ih);
+ if (unlikely(__pyx_v_ua == 0)) {
+ PyErr_SetString(PyExc_ZeroDivisionError, "float division");
+ __PYX_ERR(0, 132, __pyx_L1_error)
+ }
+ __pyx_v_ov = (__pyx_t_7 / __pyx_v_ua);
+
+ /* "nms/cpu_nms.pyx":134
+ * ov = iw * ih / ua #iou between max box and detection box
+ *
+ * if method == 1: # linear # <<<<<<<<<<<<<<
+ * if ov > Nt:
+ * weight = 1 - ov
+ */
+ switch (__pyx_v_method) {
+ case 1:
+
+ /* "nms/cpu_nms.pyx":135
+ *
+ * if method == 1: # linear
+ * if ov > Nt: # <<<<<<<<<<<<<<
+ * weight = 1 - ov
+ * else:
+ */
+ __pyx_t_9 = (__pyx_v_ov > __pyx_v_Nt);
+ if (__pyx_t_9) {
+
+ /* "nms/cpu_nms.pyx":136
+ * if method == 1: # linear
+ * if ov > Nt:
+ * weight = 1 - ov # <<<<<<<<<<<<<<
+ * else:
+ * weight = 1
+ */
+ __pyx_v_weight = (1.0 - __pyx_v_ov);
+
+ /* "nms/cpu_nms.pyx":135
+ *
+ * if method == 1: # linear
+ * if ov > Nt: # <<<<<<<<<<<<<<
+ * weight = 1 - ov
+ * else:
+ */
+ goto __pyx_L12;
+ }
+
+ /* "nms/cpu_nms.pyx":138
+ * weight = 1 - ov
+ * else:
+ * weight = 1 # <<<<<<<<<<<<<<
+ * elif method == 2: # gaussian
+ * weight = np.exp(-(ov * ov)/sigma)
+ */
+ /*else*/ {
+ __pyx_v_weight = 1.0;
+ }
+ __pyx_L12:;
+
+ /* "nms/cpu_nms.pyx":134
+ * ov = iw * ih / ua #iou between max box and detection box
+ *
+ * if method == 1: # linear # <<<<<<<<<<<<<<
+ * if ov > Nt:
+ * weight = 1 - ov
+ */
+ break;
+ case 2:
+
+ /* "nms/cpu_nms.pyx":140
+ * weight = 1
+ * elif method == 2: # gaussian
+ * weight = np.exp(-(ov * ov)/sigma) # <<<<<<<<<<<<<<
+ * else: # original NMS
+ * if ov > Nt:
+ */
+ __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_np); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 140, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_14 = __Pyx_PyObject_GetAttrStr(__pyx_t_3, __pyx_n_s_exp); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 140, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_14);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_t_7 = (-(__pyx_v_ov * __pyx_v_ov));
+ if (unlikely(__pyx_v_sigma == 0)) {
+ PyErr_SetString(PyExc_ZeroDivisionError, "float division");
+ __PYX_ERR(0, 140, __pyx_L1_error)
+ }
+ __pyx_t_3 = PyFloat_FromDouble((__pyx_t_7 / __pyx_v_sigma)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 140, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_15 = NULL;
+ __pyx_t_16 = 0;
+ #if CYTHON_UNPACK_METHODS
+ if (unlikely(PyMethod_Check(__pyx_t_14))) {
+ __pyx_t_15 = PyMethod_GET_SELF(__pyx_t_14);
+ if (likely(__pyx_t_15)) {
+ PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_14);
+ __Pyx_INCREF(__pyx_t_15);
+ __Pyx_INCREF(function);
+ __Pyx_DECREF_SET(__pyx_t_14, function);
+ __pyx_t_16 = 1;
+ }
+ }
+ #endif
+ {
+ PyObject *__pyx_callargs[2] = {__pyx_t_15, __pyx_t_3};
+ __pyx_t_6 = __Pyx_PyObject_FastCall(__pyx_t_14, __pyx_callargs+1-__pyx_t_16, 1+__pyx_t_16);
+ __Pyx_XDECREF(__pyx_t_15); __pyx_t_15 = 0;
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 140, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0;
+ }
+ __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_6); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 140, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ __pyx_v_weight = __pyx_t_7;
+
+ /* "nms/cpu_nms.pyx":139
+ * else:
+ * weight = 1
+ * elif method == 2: # gaussian # <<<<<<<<<<<<<<
+ * weight = np.exp(-(ov * ov)/sigma)
+ * else: # original NMS
+ */
+ break;
+ default:
+
+ /* "nms/cpu_nms.pyx":142
+ * weight = np.exp(-(ov * ov)/sigma)
+ * else: # original NMS
+ * if ov > Nt: # <<<<<<<<<<<<<<
+ * weight = 0
+ * else:
+ */
+ __pyx_t_9 = (__pyx_v_ov > __pyx_v_Nt);
+ if (__pyx_t_9) {
+
+ /* "nms/cpu_nms.pyx":143
+ * else: # original NMS
+ * if ov > Nt:
+ * weight = 0 # <<<<<<<<<<<<<<
+ * else:
+ * weight = 1
+ */
+ __pyx_v_weight = 0.0;
+
+ /* "nms/cpu_nms.pyx":142
+ * weight = np.exp(-(ov * ov)/sigma)
+ * else: # original NMS
+ * if ov > Nt: # <<<<<<<<<<<<<<
+ * weight = 0
+ * else:
+ */
+ goto __pyx_L13;
+ }
+
+ /* "nms/cpu_nms.pyx":145
+ * weight = 0
+ * else:
+ * weight = 1 # <<<<<<<<<<<<<<
+ *
+ * boxes[pos, 4] = weight*boxes[pos, 4]
+ */
+ /*else*/ {
+ __pyx_v_weight = 1.0;
+ }
+ __pyx_L13:;
+ break;
+ }
+
+ /* "nms/cpu_nms.pyx":147
+ * weight = 1
+ *
+ * boxes[pos, 4] = weight*boxes[pos, 4] # <<<<<<<<<<<<<<
+ *
+ * # if box score falls below threshold, discard the box by swapping with last box
+ */
+ __pyx_t_11 = __pyx_v_pos;
+ __pyx_t_10 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 147, __pyx_L1_error)
+ }
+ __pyx_t_17 = __pyx_v_pos;
+ __pyx_t_18 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_17 < 0) {
+ __pyx_t_17 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_17 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_17 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_18 < 0) {
+ __pyx_t_18 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_18 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_18 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 147, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_17, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_18, __pyx_pybuffernd_boxes.diminfo[1].strides) = (__pyx_v_weight * (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides)));
+
+ /* "nms/cpu_nms.pyx":151
+ * # if box score falls below threshold, discard the box by swapping with last box
+ * # update N
+ * if boxes[pos, 4] < threshold: # <<<<<<<<<<<<<<
+ * boxes[pos,0] = boxes[N-1, 0]
+ * boxes[pos,1] = boxes[N-1, 1]
+ */
+ __pyx_t_10 = __pyx_v_pos;
+ __pyx_t_11 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 151, __pyx_L1_error)
+ }
+ __pyx_t_9 = ((*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides)) < __pyx_v_threshold);
+ if (__pyx_t_9) {
+
+ /* "nms/cpu_nms.pyx":152
+ * # update N
+ * if boxes[pos, 4] < threshold:
+ * boxes[pos,0] = boxes[N-1, 0] # <<<<<<<<<<<<<<
+ * boxes[pos,1] = boxes[N-1, 1]
+ * boxes[pos,2] = boxes[N-1, 2]
+ */
+ __pyx_t_11 = (__pyx_v_N - 1);
+ __pyx_t_10 = 0;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 152, __pyx_L1_error)
+ }
+ __pyx_t_18 = __pyx_v_pos;
+ __pyx_t_17 = 0;
+ __pyx_t_8 = -1;
+ if (__pyx_t_18 < 0) {
+ __pyx_t_18 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_18 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_18 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_17 < 0) {
+ __pyx_t_17 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_17 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_17 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 152, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_18, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_17, __pyx_pybuffernd_boxes.diminfo[1].strides) = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":153
+ * if boxes[pos, 4] < threshold:
+ * boxes[pos,0] = boxes[N-1, 0]
+ * boxes[pos,1] = boxes[N-1, 1] # <<<<<<<<<<<<<<
+ * boxes[pos,2] = boxes[N-1, 2]
+ * boxes[pos,3] = boxes[N-1, 3]
+ */
+ __pyx_t_10 = (__pyx_v_N - 1);
+ __pyx_t_11 = 1;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 153, __pyx_L1_error)
+ }
+ __pyx_t_17 = __pyx_v_pos;
+ __pyx_t_18 = 1;
+ __pyx_t_8 = -1;
+ if (__pyx_t_17 < 0) {
+ __pyx_t_17 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_17 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_17 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_18 < 0) {
+ __pyx_t_18 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_18 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_18 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 153, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_17, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_18, __pyx_pybuffernd_boxes.diminfo[1].strides) = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":154
+ * boxes[pos,0] = boxes[N-1, 0]
+ * boxes[pos,1] = boxes[N-1, 1]
+ * boxes[pos,2] = boxes[N-1, 2] # <<<<<<<<<<<<<<
+ * boxes[pos,3] = boxes[N-1, 3]
+ * boxes[pos,4] = boxes[N-1, 4]
+ */
+ __pyx_t_11 = (__pyx_v_N - 1);
+ __pyx_t_10 = 2;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 154, __pyx_L1_error)
+ }
+ __pyx_t_18 = __pyx_v_pos;
+ __pyx_t_17 = 2;
+ __pyx_t_8 = -1;
+ if (__pyx_t_18 < 0) {
+ __pyx_t_18 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_18 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_18 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_17 < 0) {
+ __pyx_t_17 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_17 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_17 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 154, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_18, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_17, __pyx_pybuffernd_boxes.diminfo[1].strides) = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":155
+ * boxes[pos,1] = boxes[N-1, 1]
+ * boxes[pos,2] = boxes[N-1, 2]
+ * boxes[pos,3] = boxes[N-1, 3] # <<<<<<<<<<<<<<
+ * boxes[pos,4] = boxes[N-1, 4]
+ * N = N - 1
+ */
+ __pyx_t_10 = (__pyx_v_N - 1);
+ __pyx_t_11 = 3;
+ __pyx_t_8 = -1;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 155, __pyx_L1_error)
+ }
+ __pyx_t_17 = __pyx_v_pos;
+ __pyx_t_18 = 3;
+ __pyx_t_8 = -1;
+ if (__pyx_t_17 < 0) {
+ __pyx_t_17 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_17 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_17 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_18 < 0) {
+ __pyx_t_18 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_18 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_18 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 155, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_17, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_18, __pyx_pybuffernd_boxes.diminfo[1].strides) = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":156
+ * boxes[pos,2] = boxes[N-1, 2]
+ * boxes[pos,3] = boxes[N-1, 3]
+ * boxes[pos,4] = boxes[N-1, 4] # <<<<<<<<<<<<<<
+ * N = N - 1
+ * pos = pos - 1
+ */
+ __pyx_t_11 = (__pyx_v_N - 1);
+ __pyx_t_10 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_11 < 0) {
+ __pyx_t_11 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_11 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_11 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_10 < 0) {
+ __pyx_t_10 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_10 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_10 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 156, __pyx_L1_error)
+ }
+ __pyx_t_18 = __pyx_v_pos;
+ __pyx_t_17 = 4;
+ __pyx_t_8 = -1;
+ if (__pyx_t_18 < 0) {
+ __pyx_t_18 += __pyx_pybuffernd_boxes.diminfo[0].shape;
+ if (unlikely(__pyx_t_18 < 0)) __pyx_t_8 = 0;
+ } else if (unlikely(__pyx_t_18 >= __pyx_pybuffernd_boxes.diminfo[0].shape)) __pyx_t_8 = 0;
+ if (__pyx_t_17 < 0) {
+ __pyx_t_17 += __pyx_pybuffernd_boxes.diminfo[1].shape;
+ if (unlikely(__pyx_t_17 < 0)) __pyx_t_8 = 1;
+ } else if (unlikely(__pyx_t_17 >= __pyx_pybuffernd_boxes.diminfo[1].shape)) __pyx_t_8 = 1;
+ if (unlikely(__pyx_t_8 != -1)) {
+ __Pyx_RaiseBufferIndexError(__pyx_t_8);
+ __PYX_ERR(0, 156, __pyx_L1_error)
+ }
+ *__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_18, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_17, __pyx_pybuffernd_boxes.diminfo[1].strides) = (*__Pyx_BufPtrStrided2d(float *, __pyx_pybuffernd_boxes.rcbuffer->pybuffer.buf, __pyx_t_11, __pyx_pybuffernd_boxes.diminfo[0].strides, __pyx_t_10, __pyx_pybuffernd_boxes.diminfo[1].strides));
+
+ /* "nms/cpu_nms.pyx":157
+ * boxes[pos,3] = boxes[N-1, 3]
+ * boxes[pos,4] = boxes[N-1, 4]
+ * N = N - 1 # <<<<<<<<<<<<<<
+ * pos = pos - 1
+ *
+ */
+ __pyx_v_N = (__pyx_v_N - 1);
+
+ /* "nms/cpu_nms.pyx":158
+ * boxes[pos,4] = boxes[N-1, 4]
+ * N = N - 1
+ * pos = pos - 1 # <<<<<<<<<<<<<<
+ *
+ * pos = pos + 1
+ */
+ __pyx_v_pos = (__pyx_v_pos - 1);
+
+ /* "nms/cpu_nms.pyx":151
+ * # if box score falls below threshold, discard the box by swapping with last box
+ * # update N
+ * if boxes[pos, 4] < threshold: # <<<<<<<<<<<<<<
+ * boxes[pos,0] = boxes[N-1, 0]
+ * boxes[pos,1] = boxes[N-1, 1]
+ */
+ }
+
+ /* "nms/cpu_nms.pyx":130
+ * if iw > 0:
+ * ih = (min(ty2, y2) - max(ty1, y1) + 1)
+ * if ih > 0: # <<<<<<<<<<<<<<
+ * ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
+ * ov = iw * ih / ua #iou between max box and detection box
+ */
+ }
+
+ /* "nms/cpu_nms.pyx":128
+ * area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ * iw = (min(tx2, x2) - max(tx1, x1) + 1)
+ * if iw > 0: # <<<<<<<<<<<<<<
+ * ih = (min(ty2, y2) - max(ty1, y1) + 1)
+ * if ih > 0:
+ */
+ }
+
+ /* "nms/cpu_nms.pyx":160
+ * pos = pos - 1
+ *
+ * pos = pos + 1 # <<<<<<<<<<<<<<
+ *
+ * keep = [i for i in range(N)]
+ */
+ __pyx_v_pos = (__pyx_v_pos + 1);
+ }
+
+ /* "nms/cpu_nms.pyx":79
+ * cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov
+ *
+ * for i in range(N): # <<<<<<<<<<<<<<
+ * maxscore = boxes[i, 4]
+ * maxpos = i
+ */
+ }
+ __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+
+ /* "nms/cpu_nms.pyx":162
+ * pos = pos + 1
+ *
+ * keep = [i for i in range(N)] # <<<<<<<<<<<<<<
+ * return keep
+ */
+ { /* enter inner scope */
+ __pyx_t_2 = PyList_New(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 162, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_2);
+ __pyx_t_16 = __pyx_v_N;
+ __pyx_t_19 = __pyx_t_16;
+ for (__pyx_t_20 = 0; __pyx_t_20 < __pyx_t_19; __pyx_t_20+=1) {
+ __pyx_7genexpr__pyx_v_i = __pyx_t_20;
+ __pyx_t_6 = __Pyx_PyInt_From_unsigned_int(__pyx_7genexpr__pyx_v_i); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 162, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ if (unlikely(__Pyx_ListComp_Append(__pyx_t_2, (PyObject*)__pyx_t_6))) __PYX_ERR(0, 162, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ }
+ } /* exit inner scope */
+ __pyx_v_keep = ((PyObject*)__pyx_t_2);
+ __pyx_t_2 = 0;
+
+ /* "nms/cpu_nms.pyx":163
+ *
+ * keep = [i for i in range(N)]
+ * return keep # <<<<<<<<<<<<<<
+ */
+ __Pyx_XDECREF(__pyx_r);
+ __Pyx_INCREF(__pyx_v_keep);
+ __pyx_r = __pyx_v_keep;
+ goto __pyx_L0;
+
+ /* "nms/cpu_nms.pyx":70
+ * return keep
+ *
+ * def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): # <<<<<<<<<<<<<<
+ * cdef unsigned int N = boxes.shape[0]
+ * cdef float iw, ih, box_area
+ */
+
+ /* function exit code */
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_2);
+ __Pyx_XDECREF(__pyx_t_3);
+ __Pyx_XDECREF(__pyx_t_6);
+ __Pyx_XDECREF(__pyx_t_14);
+ __Pyx_XDECREF(__pyx_t_15);
+ { PyObject *__pyx_type, *__pyx_value, *__pyx_tb;
+ __Pyx_PyThreadState_declare
+ __Pyx_PyThreadState_assign
+ __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb);
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_boxes.rcbuffer->pybuffer);
+ __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);}
+ __Pyx_AddTraceback("nms.cpu_nms.cpu_soft_nms", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __pyx_r = NULL;
+ goto __pyx_L2;
+ __pyx_L0:;
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_boxes.rcbuffer->pybuffer);
+ __pyx_L2:;
+ __Pyx_XDECREF(__pyx_v_i);
+ __Pyx_XDECREF(__pyx_v_s);
+ __Pyx_XDECREF(__pyx_v_keep);
+ __Pyx_XGIVEREF(__pyx_r);
+ __Pyx_RefNannyFinishContext();
+ return __pyx_r;
+}
+
+static PyMethodDef __pyx_methods[] = {
+ {0, 0, 0, 0}
+};
+#ifndef CYTHON_SMALL_CODE
+#if defined(__clang__)
+ #define CYTHON_SMALL_CODE
+#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3))
+ #define CYTHON_SMALL_CODE __attribute__((cold))
+#else
+ #define CYTHON_SMALL_CODE
+#endif
+#endif
+/* #### Code section: pystring_table ### */
+
+static int __Pyx_CreateStringTabAndInitStrings(void) {
+ __Pyx_StringTabEntry __pyx_string_tab[] = {
+ {&__pyx_n_s_ImportError, __pyx_k_ImportError, sizeof(__pyx_k_ImportError), 0, 0, 1, 1},
+ {&__pyx_n_s_N, __pyx_k_N, sizeof(__pyx_k_N), 0, 0, 1, 1},
+ {&__pyx_n_s_Nt, __pyx_k_Nt, sizeof(__pyx_k_Nt), 0, 0, 1, 1},
+ {&__pyx_n_s__10, __pyx_k__10, sizeof(__pyx_k__10), 0, 0, 1, 1},
+ {&__pyx_n_s__15, __pyx_k__15, sizeof(__pyx_k__15), 0, 0, 1, 1},
+ {&__pyx_n_s_area, __pyx_k_area, sizeof(__pyx_k_area), 0, 0, 1, 1},
+ {&__pyx_n_s_areas, __pyx_k_areas, sizeof(__pyx_k_areas), 0, 0, 1, 1},
+ {&__pyx_n_s_argsort, __pyx_k_argsort, sizeof(__pyx_k_argsort), 0, 0, 1, 1},
+ {&__pyx_n_s_asyncio_coroutines, __pyx_k_asyncio_coroutines, sizeof(__pyx_k_asyncio_coroutines), 0, 0, 1, 1},
+ {&__pyx_n_s_box_area, __pyx_k_box_area, sizeof(__pyx_k_box_area), 0, 0, 1, 1},
+ {&__pyx_n_s_boxes, __pyx_k_boxes, sizeof(__pyx_k_boxes), 0, 0, 1, 1},
+ {&__pyx_n_s_class_getitem, __pyx_k_class_getitem, sizeof(__pyx_k_class_getitem), 0, 0, 1, 1},
+ {&__pyx_n_s_cline_in_traceback, __pyx_k_cline_in_traceback, sizeof(__pyx_k_cline_in_traceback), 0, 0, 1, 1},
+ {&__pyx_n_s_cpu_nms, __pyx_k_cpu_nms, sizeof(__pyx_k_cpu_nms), 0, 0, 1, 1},
+ {&__pyx_n_s_cpu_soft_nms, __pyx_k_cpu_soft_nms, sizeof(__pyx_k_cpu_soft_nms), 0, 0, 1, 1},
+ {&__pyx_n_s_dets, __pyx_k_dets, sizeof(__pyx_k_dets), 0, 0, 1, 1},
+ {&__pyx_n_s_dtype, __pyx_k_dtype, sizeof(__pyx_k_dtype), 0, 0, 1, 1},
+ {&__pyx_n_s_exp, __pyx_k_exp, sizeof(__pyx_k_exp), 0, 0, 1, 1},
+ {&__pyx_n_s_h, __pyx_k_h, sizeof(__pyx_k_h), 0, 0, 1, 1},
+ {&__pyx_n_s_i, __pyx_k_i, sizeof(__pyx_k_i), 0, 0, 1, 1},
+ {&__pyx_n_s_i_2, __pyx_k_i_2, sizeof(__pyx_k_i_2), 0, 0, 1, 1},
+ {&__pyx_n_s_iarea, __pyx_k_iarea, sizeof(__pyx_k_iarea), 0, 0, 1, 1},
+ {&__pyx_n_s_ih, __pyx_k_ih, sizeof(__pyx_k_ih), 0, 0, 1, 1},
+ {&__pyx_n_s_import, __pyx_k_import, sizeof(__pyx_k_import), 0, 0, 1, 1},
+ {&__pyx_n_s_initializing, __pyx_k_initializing, sizeof(__pyx_k_initializing), 0, 0, 1, 1},
+ {&__pyx_n_s_int, __pyx_k_int, sizeof(__pyx_k_int), 0, 0, 1, 1},
+ {&__pyx_n_s_inter, __pyx_k_inter, sizeof(__pyx_k_inter), 0, 0, 1, 1},
+ {&__pyx_n_s_is_coroutine, __pyx_k_is_coroutine, sizeof(__pyx_k_is_coroutine), 0, 0, 1, 1},
+ {&__pyx_n_s_iw, __pyx_k_iw, sizeof(__pyx_k_iw), 0, 0, 1, 1},
+ {&__pyx_n_s_ix1, __pyx_k_ix1, sizeof(__pyx_k_ix1), 0, 0, 1, 1},
+ {&__pyx_n_s_ix2, __pyx_k_ix2, sizeof(__pyx_k_ix2), 0, 0, 1, 1},
+ {&__pyx_n_s_iy1, __pyx_k_iy1, sizeof(__pyx_k_iy1), 0, 0, 1, 1},
+ {&__pyx_n_s_iy2, __pyx_k_iy2, sizeof(__pyx_k_iy2), 0, 0, 1, 1},
+ {&__pyx_n_s_j, __pyx_k_j, sizeof(__pyx_k_j), 0, 0, 1, 1},
+ {&__pyx_n_s_j_2, __pyx_k_j_2, sizeof(__pyx_k_j_2), 0, 0, 1, 1},
+ {&__pyx_n_s_keep, __pyx_k_keep, sizeof(__pyx_k_keep), 0, 0, 1, 1},
+ {&__pyx_n_s_main, __pyx_k_main, sizeof(__pyx_k_main), 0, 0, 1, 1},
+ {&__pyx_n_s_maxpos, __pyx_k_maxpos, sizeof(__pyx_k_maxpos), 0, 0, 1, 1},
+ {&__pyx_n_s_maxscore, __pyx_k_maxscore, sizeof(__pyx_k_maxscore), 0, 0, 1, 1},
+ {&__pyx_n_s_method, __pyx_k_method, sizeof(__pyx_k_method), 0, 0, 1, 1},
+ {&__pyx_n_s_name, __pyx_k_name, sizeof(__pyx_k_name), 0, 0, 1, 1},
+ {&__pyx_n_s_ndets, __pyx_k_ndets, sizeof(__pyx_k_ndets), 0, 0, 1, 1},
+ {&__pyx_n_s_nms_cpu_nms, __pyx_k_nms_cpu_nms, sizeof(__pyx_k_nms_cpu_nms), 0, 0, 1, 1},
+ {&__pyx_kp_s_nms_cpu_nms_pyx, __pyx_k_nms_cpu_nms_pyx, sizeof(__pyx_k_nms_cpu_nms_pyx), 0, 0, 1, 0},
+ {&__pyx_n_s_np, __pyx_k_np, sizeof(__pyx_k_np), 0, 0, 1, 1},
+ {&__pyx_n_s_numpy, __pyx_k_numpy, sizeof(__pyx_k_numpy), 0, 0, 1, 1},
+ {&__pyx_kp_s_numpy_core_multiarray_failed_to, __pyx_k_numpy_core_multiarray_failed_to, sizeof(__pyx_k_numpy_core_multiarray_failed_to), 0, 0, 1, 0},
+ {&__pyx_kp_s_numpy_core_umath_failed_to_impor, __pyx_k_numpy_core_umath_failed_to_impor, sizeof(__pyx_k_numpy_core_umath_failed_to_impor), 0, 0, 1, 0},
+ {&__pyx_n_s_order, __pyx_k_order, sizeof(__pyx_k_order), 0, 0, 1, 1},
+ {&__pyx_n_s_ov, __pyx_k_ov, sizeof(__pyx_k_ov), 0, 0, 1, 1},
+ {&__pyx_n_s_ovr, __pyx_k_ovr, sizeof(__pyx_k_ovr), 0, 0, 1, 1},
+ {&__pyx_n_s_pos, __pyx_k_pos, sizeof(__pyx_k_pos), 0, 0, 1, 1},
+ {&__pyx_n_s_range, __pyx_k_range, sizeof(__pyx_k_range), 0, 0, 1, 1},
+ {&__pyx_n_s_s, __pyx_k_s, sizeof(__pyx_k_s), 0, 0, 1, 1},
+ {&__pyx_n_s_scores, __pyx_k_scores, sizeof(__pyx_k_scores), 0, 0, 1, 1},
+ {&__pyx_n_s_sigma, __pyx_k_sigma, sizeof(__pyx_k_sigma), 0, 0, 1, 1},
+ {&__pyx_n_s_spec, __pyx_k_spec, sizeof(__pyx_k_spec), 0, 0, 1, 1},
+ {&__pyx_n_s_suppressed, __pyx_k_suppressed, sizeof(__pyx_k_suppressed), 0, 0, 1, 1},
+ {&__pyx_n_s_test, __pyx_k_test, sizeof(__pyx_k_test), 0, 0, 1, 1},
+ {&__pyx_n_s_thresh, __pyx_k_thresh, sizeof(__pyx_k_thresh), 0, 0, 1, 1},
+ {&__pyx_n_s_threshold, __pyx_k_threshold, sizeof(__pyx_k_threshold), 0, 0, 1, 1},
+ {&__pyx_n_s_ts, __pyx_k_ts, sizeof(__pyx_k_ts), 0, 0, 1, 1},
+ {&__pyx_n_s_tx1, __pyx_k_tx1, sizeof(__pyx_k_tx1), 0, 0, 1, 1},
+ {&__pyx_n_s_tx2, __pyx_k_tx2, sizeof(__pyx_k_tx2), 0, 0, 1, 1},
+ {&__pyx_n_s_ty1, __pyx_k_ty1, sizeof(__pyx_k_ty1), 0, 0, 1, 1},
+ {&__pyx_n_s_ty2, __pyx_k_ty2, sizeof(__pyx_k_ty2), 0, 0, 1, 1},
+ {&__pyx_n_s_ua, __pyx_k_ua, sizeof(__pyx_k_ua), 0, 0, 1, 1},
+ {&__pyx_n_s_w, __pyx_k_w, sizeof(__pyx_k_w), 0, 0, 1, 1},
+ {&__pyx_n_s_weight, __pyx_k_weight, sizeof(__pyx_k_weight), 0, 0, 1, 1},
+ {&__pyx_n_s_x1, __pyx_k_x1, sizeof(__pyx_k_x1), 0, 0, 1, 1},
+ {&__pyx_n_s_x2, __pyx_k_x2, sizeof(__pyx_k_x2), 0, 0, 1, 1},
+ {&__pyx_n_s_xx1, __pyx_k_xx1, sizeof(__pyx_k_xx1), 0, 0, 1, 1},
+ {&__pyx_n_s_xx2, __pyx_k_xx2, sizeof(__pyx_k_xx2), 0, 0, 1, 1},
+ {&__pyx_n_s_y1, __pyx_k_y1, sizeof(__pyx_k_y1), 0, 0, 1, 1},
+ {&__pyx_n_s_y2, __pyx_k_y2, sizeof(__pyx_k_y2), 0, 0, 1, 1},
+ {&__pyx_n_s_yy1, __pyx_k_yy1, sizeof(__pyx_k_yy1), 0, 0, 1, 1},
+ {&__pyx_n_s_yy2, __pyx_k_yy2, sizeof(__pyx_k_yy2), 0, 0, 1, 1},
+ {&__pyx_n_s_zeros, __pyx_k_zeros, sizeof(__pyx_k_zeros), 0, 0, 1, 1},
+ {0, 0, 0, 0, 0, 0, 0}
+ };
+ return __Pyx_InitStrings(__pyx_string_tab);
+}
+/* #### Code section: cached_builtins ### */
+static CYTHON_SMALL_CODE int __Pyx_InitCachedBuiltins(void) {
+ __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 43, __pyx_L1_error)
+ __pyx_builtin_ImportError = __Pyx_GetBuiltinName(__pyx_n_s_ImportError); if (!__pyx_builtin_ImportError) __PYX_ERR(1, 986, __pyx_L1_error)
+ return 0;
+ __pyx_L1_error:;
+ return -1;
+}
+/* #### Code section: cached_constants ### */
+
+static CYTHON_SMALL_CODE int __Pyx_InitCachedConstants(void) {
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__Pyx_InitCachedConstants", 0);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":986
+ * __pyx_import_array()
+ * except Exception:
+ * raise ImportError("numpy.core.multiarray failed to import") # <<<<<<<<<<<<<<
+ *
+ * cdef inline int import_umath() except -1:
+ */
+ __pyx_tuple_ = PyTuple_Pack(1, __pyx_kp_s_numpy_core_multiarray_failed_to); if (unlikely(!__pyx_tuple_)) __PYX_ERR(1, 986, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_tuple_);
+ __Pyx_GIVEREF(__pyx_tuple_);
+
+ /* "../../../../../../../../conda_envs/gagavatar/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":992
+ * _import_umath()
+ * except Exception:
+ * raise ImportError("numpy.core.umath failed to import") # <<<<<<<<<<<<<<
+ *
+ * cdef inline int import_ufunc() except -1:
+ */
+ __pyx_tuple__2 = PyTuple_Pack(1, __pyx_kp_s_numpy_core_umath_failed_to_impor); if (unlikely(!__pyx_tuple__2)) __PYX_ERR(1, 992, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_tuple__2);
+ __Pyx_GIVEREF(__pyx_tuple__2);
+
+ /* "nms/cpu_nms.pyx":18
+ *
+ * def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+ */
+ __pyx_slice__3 = PySlice_New(Py_None, Py_None, Py_None); if (unlikely(!__pyx_slice__3)) __PYX_ERR(0, 18, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_slice__3);
+ __Pyx_GIVEREF(__pyx_slice__3);
+ __pyx_tuple__4 = PyTuple_Pack(2, __pyx_slice__3, __pyx_int_0); if (unlikely(!__pyx_tuple__4)) __PYX_ERR(0, 18, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_tuple__4);
+ __Pyx_GIVEREF(__pyx_tuple__4);
+
+ /* "nms/cpu_nms.pyx":19
+ * def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+ * cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
+ */
+ __pyx_tuple__5 = PyTuple_Pack(2, __pyx_slice__3, __pyx_int_1); if (unlikely(!__pyx_tuple__5)) __PYX_ERR(0, 19, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_tuple__5);
+ __Pyx_GIVEREF(__pyx_tuple__5);
+
+ /* "nms/cpu_nms.pyx":20
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
+ * cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
+ */
+ __pyx_tuple__6 = PyTuple_Pack(2, __pyx_slice__3, __pyx_int_2); if (unlikely(!__pyx_tuple__6)) __PYX_ERR(0, 20, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_tuple__6);
+ __Pyx_GIVEREF(__pyx_tuple__6);
+
+ /* "nms/cpu_nms.pyx":21
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+ * cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
+ *
+ */
+ __pyx_tuple__7 = PyTuple_Pack(2, __pyx_slice__3, __pyx_int_3); if (unlikely(!__pyx_tuple__7)) __PYX_ERR(0, 21, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_tuple__7);
+ __Pyx_GIVEREF(__pyx_tuple__7);
+
+ /* "nms/cpu_nms.pyx":22
+ * cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+ * cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
+ * cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] # <<<<<<<<<<<<<<
+ *
+ * cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ */
+ __pyx_tuple__8 = PyTuple_Pack(2, __pyx_slice__3, __pyx_int_4); if (unlikely(!__pyx_tuple__8)) __PYX_ERR(0, 22, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_tuple__8);
+ __Pyx_GIVEREF(__pyx_tuple__8);
+
+ /* "nms/cpu_nms.pyx":25
+ *
+ * cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ * cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1] # <<<<<<<<<<<<<<
+ *
+ * cdef int ndets = dets.shape[0]
+ */
+ __pyx_slice__9 = PySlice_New(Py_None, Py_None, __pyx_int_neg_1); if (unlikely(!__pyx_slice__9)) __PYX_ERR(0, 25, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_slice__9);
+ __Pyx_GIVEREF(__pyx_slice__9);
+
+ /* "nms/cpu_nms.pyx":17
+ * return a if a <= b else b
+ *
+ * def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ */
+ __pyx_tuple__11 = PyTuple_Pack(29, __pyx_n_s_dets, __pyx_n_s_thresh, __pyx_n_s_x1, __pyx_n_s_y1, __pyx_n_s_x2, __pyx_n_s_y2, __pyx_n_s_scores, __pyx_n_s_areas, __pyx_n_s_order, __pyx_n_s_ndets, __pyx_n_s_suppressed, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_i_2, __pyx_n_s_j_2, __pyx_n_s_ix1, __pyx_n_s_iy1, __pyx_n_s_ix2, __pyx_n_s_iy2, __pyx_n_s_iarea, __pyx_n_s_xx1, __pyx_n_s_yy1, __pyx_n_s_xx2, __pyx_n_s_yy2, __pyx_n_s_w, __pyx_n_s_h, __pyx_n_s_inter, __pyx_n_s_ovr, __pyx_n_s_keep); if (unlikely(!__pyx_tuple__11)) __PYX_ERR(0, 17, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_tuple__11);
+ __Pyx_GIVEREF(__pyx_tuple__11);
+ __pyx_codeobj__12 = (PyObject*)__Pyx_PyCode_New(2, 0, 0, 29, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__11, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_nms_cpu_nms_pyx, __pyx_n_s_cpu_nms, 17, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__12)) __PYX_ERR(0, 17, __pyx_L1_error)
+
+ /* "nms/cpu_nms.pyx":70
+ * return keep
+ *
+ * def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): # <<<<<<<<<<<<<<
+ * cdef unsigned int N = boxes.shape[0]
+ * cdef float iw, ih, box_area
+ */
+ __pyx_tuple__13 = PyTuple_Pack(29, __pyx_n_s_boxes, __pyx_n_s_sigma, __pyx_n_s_Nt, __pyx_n_s_threshold, __pyx_n_s_method, __pyx_n_s_N, __pyx_n_s_iw, __pyx_n_s_ih, __pyx_n_s_box_area, __pyx_n_s_ua, __pyx_n_s_pos, __pyx_n_s_maxscore, __pyx_n_s_maxpos, __pyx_n_s_x1, __pyx_n_s_x2, __pyx_n_s_y1, __pyx_n_s_y2, __pyx_n_s_tx1, __pyx_n_s_tx2, __pyx_n_s_ty1, __pyx_n_s_ty2, __pyx_n_s_ts, __pyx_n_s_area, __pyx_n_s_weight, __pyx_n_s_ov, __pyx_n_s_i_2, __pyx_n_s_s, __pyx_n_s_keep, __pyx_n_s_i_2); if (unlikely(!__pyx_tuple__13)) __PYX_ERR(0, 70, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_tuple__13);
+ __Pyx_GIVEREF(__pyx_tuple__13);
+ __pyx_codeobj__14 = (PyObject*)__Pyx_PyCode_New(5, 0, 0, 29, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__13, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_nms_cpu_nms_pyx, __pyx_n_s_cpu_soft_nms, 70, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__14)) __PYX_ERR(0, 70, __pyx_L1_error)
+ __Pyx_RefNannyFinishContext();
+ return 0;
+ __pyx_L1_error:;
+ __Pyx_RefNannyFinishContext();
+ return -1;
+}
+/* #### Code section: init_constants ### */
+
+static CYTHON_SMALL_CODE int __Pyx_InitConstants(void) {
+ if (__Pyx_CreateStringTabAndInitStrings() < 0) __PYX_ERR(0, 1, __pyx_L1_error);
+ __pyx_int_0 = PyInt_FromLong(0); if (unlikely(!__pyx_int_0)) __PYX_ERR(0, 1, __pyx_L1_error)
+ __pyx_int_1 = PyInt_FromLong(1); if (unlikely(!__pyx_int_1)) __PYX_ERR(0, 1, __pyx_L1_error)
+ __pyx_int_2 = PyInt_FromLong(2); if (unlikely(!__pyx_int_2)) __PYX_ERR(0, 1, __pyx_L1_error)
+ __pyx_int_3 = PyInt_FromLong(3); if (unlikely(!__pyx_int_3)) __PYX_ERR(0, 1, __pyx_L1_error)
+ __pyx_int_4 = PyInt_FromLong(4); if (unlikely(!__pyx_int_4)) __PYX_ERR(0, 1, __pyx_L1_error)
+ __pyx_int_neg_1 = PyInt_FromLong(-1); if (unlikely(!__pyx_int_neg_1)) __PYX_ERR(0, 1, __pyx_L1_error)
+ return 0;
+ __pyx_L1_error:;
+ return -1;
+}
+/* #### Code section: init_globals ### */
+
+static CYTHON_SMALL_CODE int __Pyx_InitGlobals(void) {
+ /* NumpyImportArray.init */
+ /*
+ * Cython has automatically inserted a call to _import_array since
+ * you didn't include one when you cimported numpy. To disable this
+ * add the line
+ * numpy._import_array
+ */
+#ifdef NPY_FEATURE_VERSION
+#ifndef NO_IMPORT_ARRAY
+if (unlikely(_import_array() == -1)) {
+ PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import "
+ "(auto-generated because you didn't call 'numpy.import_array()' after cimporting numpy; "
+ "use 'numpy._import_array' to disable if you are certain you don't need it).");
+}
+#endif
+#endif
+
+if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 1, __pyx_L1_error)
+
+ return 0;
+ __pyx_L1_error:;
+ return -1;
+}
+/* #### Code section: init_module ### */
+
+static CYTHON_SMALL_CODE int __Pyx_modinit_global_init_code(void); /*proto*/
+static CYTHON_SMALL_CODE int __Pyx_modinit_variable_export_code(void); /*proto*/
+static CYTHON_SMALL_CODE int __Pyx_modinit_function_export_code(void); /*proto*/
+static CYTHON_SMALL_CODE int __Pyx_modinit_type_init_code(void); /*proto*/
+static CYTHON_SMALL_CODE int __Pyx_modinit_type_import_code(void); /*proto*/
+static CYTHON_SMALL_CODE int __Pyx_modinit_variable_import_code(void); /*proto*/
+static CYTHON_SMALL_CODE int __Pyx_modinit_function_import_code(void); /*proto*/
+
+static int __Pyx_modinit_global_init_code(void) {
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__Pyx_modinit_global_init_code", 0);
+ /*--- Global init code ---*/
+ __Pyx_RefNannyFinishContext();
+ return 0;
+}
+
+static int __Pyx_modinit_variable_export_code(void) {
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__Pyx_modinit_variable_export_code", 0);
+ /*--- Variable export code ---*/
+ __Pyx_RefNannyFinishContext();
+ return 0;
+}
+
+static int __Pyx_modinit_function_export_code(void) {
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__Pyx_modinit_function_export_code", 0);
+ /*--- Function export code ---*/
+ __Pyx_RefNannyFinishContext();
+ return 0;
+}
+
+static int __Pyx_modinit_type_init_code(void) {
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__Pyx_modinit_type_init_code", 0);
+ /*--- Type init code ---*/
+ __Pyx_RefNannyFinishContext();
+ return 0;
+}
+
+static int __Pyx_modinit_type_import_code(void) {
+ __Pyx_RefNannyDeclarations
+ PyObject *__pyx_t_1 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannySetupContext("__Pyx_modinit_type_import_code", 0);
+ /*--- Type import code ---*/
+ __pyx_t_1 = PyImport_ImportModule(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 9, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_ptype_7cpython_4type_type = __Pyx_ImportType_3_0_12(__pyx_t_1, __Pyx_BUILTIN_MODULE_NAME, "type",
+ #if defined(PYPY_VERSION_NUM) && PYPY_VERSION_NUM < 0x050B0000
+ sizeof(PyTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyTypeObject),
+ #elif CYTHON_COMPILING_IN_LIMITED_API
+ sizeof(PyTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyTypeObject),
+ #else
+ sizeof(PyHeapTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyHeapTypeObject),
+ #endif
+ __Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_7cpython_4type_type) __PYX_ERR(2, 9, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __pyx_t_1 = PyImport_ImportModule("numpy"); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 202, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_ptype_5numpy_dtype = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "dtype", sizeof(PyArray_Descr), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyArray_Descr),__Pyx_ImportType_CheckSize_Ignore_3_0_12); if (!__pyx_ptype_5numpy_dtype) __PYX_ERR(1, 202, __pyx_L1_error)
+ __pyx_ptype_5numpy_flatiter = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "flatiter", sizeof(PyArrayIterObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyArrayIterObject),__Pyx_ImportType_CheckSize_Ignore_3_0_12); if (!__pyx_ptype_5numpy_flatiter) __PYX_ERR(1, 225, __pyx_L1_error)
+ __pyx_ptype_5numpy_broadcast = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "broadcast", sizeof(PyArrayMultiIterObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyArrayMultiIterObject),__Pyx_ImportType_CheckSize_Ignore_3_0_12); if (!__pyx_ptype_5numpy_broadcast) __PYX_ERR(1, 229, __pyx_L1_error)
+ __pyx_ptype_5numpy_ndarray = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "ndarray", sizeof(PyArrayObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyArrayObject),__Pyx_ImportType_CheckSize_Ignore_3_0_12); if (!__pyx_ptype_5numpy_ndarray) __PYX_ERR(1, 238, __pyx_L1_error)
+ __pyx_ptype_5numpy_generic = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "generic", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_generic) __PYX_ERR(1, 812, __pyx_L1_error)
+ __pyx_ptype_5numpy_number = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "number", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_number) __PYX_ERR(1, 814, __pyx_L1_error)
+ __pyx_ptype_5numpy_integer = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "integer", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_integer) __PYX_ERR(1, 816, __pyx_L1_error)
+ __pyx_ptype_5numpy_signedinteger = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "signedinteger", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_signedinteger) __PYX_ERR(1, 818, __pyx_L1_error)
+ __pyx_ptype_5numpy_unsignedinteger = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "unsignedinteger", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_unsignedinteger) __PYX_ERR(1, 820, __pyx_L1_error)
+ __pyx_ptype_5numpy_inexact = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "inexact", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_inexact) __PYX_ERR(1, 822, __pyx_L1_error)
+ __pyx_ptype_5numpy_floating = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "floating", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_floating) __PYX_ERR(1, 824, __pyx_L1_error)
+ __pyx_ptype_5numpy_complexfloating = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "complexfloating", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_complexfloating) __PYX_ERR(1, 826, __pyx_L1_error)
+ __pyx_ptype_5numpy_flexible = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "flexible", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_flexible) __PYX_ERR(1, 828, __pyx_L1_error)
+ __pyx_ptype_5numpy_character = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "character", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_12); if (!__pyx_ptype_5numpy_character) __PYX_ERR(1, 830, __pyx_L1_error)
+ __pyx_ptype_5numpy_ufunc = __Pyx_ImportType_3_0_12(__pyx_t_1, "numpy", "ufunc", sizeof(PyUFuncObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_12(PyUFuncObject),__Pyx_ImportType_CheckSize_Ignore_3_0_12); if (!__pyx_ptype_5numpy_ufunc) __PYX_ERR(1, 868, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __Pyx_RefNannyFinishContext();
+ return 0;
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_1);
+ __Pyx_RefNannyFinishContext();
+ return -1;
+}
+
+static int __Pyx_modinit_variable_import_code(void) {
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__Pyx_modinit_variable_import_code", 0);
+ /*--- Variable import code ---*/
+ __Pyx_RefNannyFinishContext();
+ return 0;
+}
+
+static int __Pyx_modinit_function_import_code(void) {
+ __Pyx_RefNannyDeclarations
+ __Pyx_RefNannySetupContext("__Pyx_modinit_function_import_code", 0);
+ /*--- Function import code ---*/
+ __Pyx_RefNannyFinishContext();
+ return 0;
+}
+
+
+#if PY_MAJOR_VERSION >= 3
+#if CYTHON_PEP489_MULTI_PHASE_INIT
+static PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def); /*proto*/
+static int __pyx_pymod_exec_cpu_nms(PyObject* module); /*proto*/
+static PyModuleDef_Slot __pyx_moduledef_slots[] = {
+ {Py_mod_create, (void*)__pyx_pymod_create},
+ {Py_mod_exec, (void*)__pyx_pymod_exec_cpu_nms},
+ {0, NULL}
+};
+#endif
+
+#ifdef __cplusplus
+namespace {
+ struct PyModuleDef __pyx_moduledef =
+ #else
+ static struct PyModuleDef __pyx_moduledef =
+ #endif
+ {
+ PyModuleDef_HEAD_INIT,
+ "cpu_nms",
+ 0, /* m_doc */
+ #if CYTHON_PEP489_MULTI_PHASE_INIT
+ 0, /* m_size */
+ #elif CYTHON_USE_MODULE_STATE
+ sizeof(__pyx_mstate), /* m_size */
+ #else
+ -1, /* m_size */
+ #endif
+ __pyx_methods /* m_methods */,
+ #if CYTHON_PEP489_MULTI_PHASE_INIT
+ __pyx_moduledef_slots, /* m_slots */
+ #else
+ NULL, /* m_reload */
+ #endif
+ #if CYTHON_USE_MODULE_STATE
+ __pyx_m_traverse, /* m_traverse */
+ __pyx_m_clear, /* m_clear */
+ NULL /* m_free */
+ #else
+ NULL, /* m_traverse */
+ NULL, /* m_clear */
+ NULL /* m_free */
+ #endif
+ };
+ #ifdef __cplusplus
+} /* anonymous namespace */
+#endif
+#endif
+
+#ifndef CYTHON_NO_PYINIT_EXPORT
+#define __Pyx_PyMODINIT_FUNC PyMODINIT_FUNC
+#elif PY_MAJOR_VERSION < 3
+#ifdef __cplusplus
+#define __Pyx_PyMODINIT_FUNC extern "C" void
+#else
+#define __Pyx_PyMODINIT_FUNC void
+#endif
+#else
+#ifdef __cplusplus
+#define __Pyx_PyMODINIT_FUNC extern "C" PyObject *
+#else
+#define __Pyx_PyMODINIT_FUNC PyObject *
+#endif
+#endif
+
+
+#if PY_MAJOR_VERSION < 3
+__Pyx_PyMODINIT_FUNC initcpu_nms(void) CYTHON_SMALL_CODE; /*proto*/
+__Pyx_PyMODINIT_FUNC initcpu_nms(void)
+#else
+__Pyx_PyMODINIT_FUNC PyInit_cpu_nms(void) CYTHON_SMALL_CODE; /*proto*/
+__Pyx_PyMODINIT_FUNC PyInit_cpu_nms(void)
+#if CYTHON_PEP489_MULTI_PHASE_INIT
+{
+ return PyModuleDef_Init(&__pyx_moduledef);
+}
+static CYTHON_SMALL_CODE int __Pyx_check_single_interpreter(void) {
+ #if PY_VERSION_HEX >= 0x030700A1
+ static PY_INT64_T main_interpreter_id = -1;
+ PY_INT64_T current_id = PyInterpreterState_GetID(PyThreadState_Get()->interp);
+ if (main_interpreter_id == -1) {
+ main_interpreter_id = current_id;
+ return (unlikely(current_id == -1)) ? -1 : 0;
+ } else if (unlikely(main_interpreter_id != current_id))
+ #else
+ static PyInterpreterState *main_interpreter = NULL;
+ PyInterpreterState *current_interpreter = PyThreadState_Get()->interp;
+ if (!main_interpreter) {
+ main_interpreter = current_interpreter;
+ } else if (unlikely(main_interpreter != current_interpreter))
+ #endif
+ {
+ PyErr_SetString(
+ PyExc_ImportError,
+ "Interpreter change detected - this module can only be loaded into one interpreter per process.");
+ return -1;
+ }
+ return 0;
+}
+#if CYTHON_COMPILING_IN_LIMITED_API
+static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *module, const char* from_name, const char* to_name, int allow_none)
+#else
+static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *moddict, const char* from_name, const char* to_name, int allow_none)
+#endif
+{
+ PyObject *value = PyObject_GetAttrString(spec, from_name);
+ int result = 0;
+ if (likely(value)) {
+ if (allow_none || value != Py_None) {
+#if CYTHON_COMPILING_IN_LIMITED_API
+ result = PyModule_AddObject(module, to_name, value);
+#else
+ result = PyDict_SetItemString(moddict, to_name, value);
+#endif
+ }
+ Py_DECREF(value);
+ } else if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
+ PyErr_Clear();
+ } else {
+ result = -1;
+ }
+ return result;
+}
+static CYTHON_SMALL_CODE PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def) {
+ PyObject *module = NULL, *moddict, *modname;
+ CYTHON_UNUSED_VAR(def);
+ if (__Pyx_check_single_interpreter())
+ return NULL;
+ if (__pyx_m)
+ return __Pyx_NewRef(__pyx_m);
+ modname = PyObject_GetAttrString(spec, "name");
+ if (unlikely(!modname)) goto bad;
+ module = PyModule_NewObject(modname);
+ Py_DECREF(modname);
+ if (unlikely(!module)) goto bad;
+#if CYTHON_COMPILING_IN_LIMITED_API
+ moddict = module;
+#else
+ moddict = PyModule_GetDict(module);
+ if (unlikely(!moddict)) goto bad;
+#endif
+ if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "loader", "__loader__", 1) < 0)) goto bad;
+ if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "origin", "__file__", 1) < 0)) goto bad;
+ if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "parent", "__package__", 1) < 0)) goto bad;
+ if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "submodule_search_locations", "__path__", 0) < 0)) goto bad;
+ return module;
+bad:
+ Py_XDECREF(module);
+ return NULL;
+}
+
+
+static CYTHON_SMALL_CODE int __pyx_pymod_exec_cpu_nms(PyObject *__pyx_pyinit_module)
+#endif
+#endif
+{
+ int stringtab_initialized = 0;
+ #if CYTHON_USE_MODULE_STATE
+ int pystate_addmodule_run = 0;
+ #endif
+ PyObject *__pyx_t_1 = NULL;
+ PyObject *__pyx_t_2 = NULL;
+ PyObject *__pyx_t_3 = NULL;
+ PyObject *__pyx_t_4 = NULL;
+ PyObject *__pyx_t_5 = NULL;
+ PyObject *__pyx_t_6 = NULL;
+ int __pyx_lineno = 0;
+ const char *__pyx_filename = NULL;
+ int __pyx_clineno = 0;
+ __Pyx_RefNannyDeclarations
+ #if CYTHON_PEP489_MULTI_PHASE_INIT
+ if (__pyx_m) {
+ if (__pyx_m == __pyx_pyinit_module) return 0;
+ PyErr_SetString(PyExc_RuntimeError, "Module 'cpu_nms' has already been imported. Re-initialisation is not supported.");
+ return -1;
+ }
+ #elif PY_MAJOR_VERSION >= 3
+ if (__pyx_m) return __Pyx_NewRef(__pyx_m);
+ #endif
+ /*--- Module creation code ---*/
+ #if CYTHON_PEP489_MULTI_PHASE_INIT
+ __pyx_m = __pyx_pyinit_module;
+ Py_INCREF(__pyx_m);
+ #else
+ #if PY_MAJOR_VERSION < 3
+ __pyx_m = Py_InitModule4("cpu_nms", __pyx_methods, 0, 0, PYTHON_API_VERSION); Py_XINCREF(__pyx_m);
+ if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error)
+ #elif CYTHON_USE_MODULE_STATE
+ __pyx_t_1 = PyModule_Create(&__pyx_moduledef); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)
+ {
+ int add_module_result = PyState_AddModule(__pyx_t_1, &__pyx_moduledef);
+ __pyx_t_1 = 0; /* transfer ownership from __pyx_t_1 to "cpu_nms" pseudovariable */
+ if (unlikely((add_module_result < 0))) __PYX_ERR(0, 1, __pyx_L1_error)
+ pystate_addmodule_run = 1;
+ }
+ #else
+ __pyx_m = PyModule_Create(&__pyx_moduledef);
+ if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error)
+ #endif
+ #endif
+ CYTHON_UNUSED_VAR(__pyx_t_1);
+ __pyx_d = PyModule_GetDict(__pyx_m); if (unlikely(!__pyx_d)) __PYX_ERR(0, 1, __pyx_L1_error)
+ Py_INCREF(__pyx_d);
+ __pyx_b = __Pyx_PyImport_AddModuleRef(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_b)) __PYX_ERR(0, 1, __pyx_L1_error)
+ __pyx_cython_runtime = __Pyx_PyImport_AddModuleRef((const char *) "cython_runtime"); if (unlikely(!__pyx_cython_runtime)) __PYX_ERR(0, 1, __pyx_L1_error)
+ if (PyObject_SetAttrString(__pyx_m, "__builtins__", __pyx_b) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #if CYTHON_REFNANNY
+__Pyx_RefNanny = __Pyx_RefNannyImportAPI("refnanny");
+if (!__Pyx_RefNanny) {
+ PyErr_Clear();
+ __Pyx_RefNanny = __Pyx_RefNannyImportAPI("Cython.Runtime.refnanny");
+ if (!__Pyx_RefNanny)
+ Py_FatalError("failed to import 'refnanny' module");
+}
+#endif
+ __Pyx_RefNannySetupContext("__Pyx_PyMODINIT_FUNC PyInit_cpu_nms(void)", 0);
+ if (__Pyx_check_binary_version(__PYX_LIMITED_VERSION_HEX, __Pyx_get_runtime_version(), CYTHON_COMPILING_IN_LIMITED_API) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #ifdef __Pxy_PyFrame_Initialize_Offsets
+ __Pxy_PyFrame_Initialize_Offsets();
+ #endif
+ __pyx_empty_tuple = PyTuple_New(0); if (unlikely(!__pyx_empty_tuple)) __PYX_ERR(0, 1, __pyx_L1_error)
+ __pyx_empty_bytes = PyBytes_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_bytes)) __PYX_ERR(0, 1, __pyx_L1_error)
+ __pyx_empty_unicode = PyUnicode_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_unicode)) __PYX_ERR(0, 1, __pyx_L1_error)
+ #ifdef __Pyx_CyFunction_USED
+ if (__pyx_CyFunction_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #endif
+ #ifdef __Pyx_FusedFunction_USED
+ if (__pyx_FusedFunction_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #endif
+ #ifdef __Pyx_Coroutine_USED
+ if (__pyx_Coroutine_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #endif
+ #ifdef __Pyx_Generator_USED
+ if (__pyx_Generator_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #endif
+ #ifdef __Pyx_AsyncGen_USED
+ if (__pyx_AsyncGen_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #endif
+ #ifdef __Pyx_StopAsyncIteration_USED
+ if (__pyx_StopAsyncIteration_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #endif
+ /*--- Library function declarations ---*/
+ /*--- Threads initialization code ---*/
+ #if defined(WITH_THREAD) && PY_VERSION_HEX < 0x030700F0 && defined(__PYX_FORCE_INIT_THREADS) && __PYX_FORCE_INIT_THREADS
+ PyEval_InitThreads();
+ #endif
+ /*--- Initialize various global constants etc. ---*/
+ if (__Pyx_InitConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ stringtab_initialized = 1;
+ if (__Pyx_InitGlobals() < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #if PY_MAJOR_VERSION < 3 && (__PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT)
+ if (__Pyx_init_sys_getdefaultencoding_params() < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #endif
+ if (__pyx_module_is_main_nms__cpu_nms) {
+ if (PyObject_SetAttr(__pyx_m, __pyx_n_s_name, __pyx_n_s_main) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ }
+ #if PY_MAJOR_VERSION >= 3
+ {
+ PyObject *modules = PyImport_GetModuleDict(); if (unlikely(!modules)) __PYX_ERR(0, 1, __pyx_L1_error)
+ if (!PyDict_GetItemString(modules, "nms.cpu_nms")) {
+ if (unlikely((PyDict_SetItemString(modules, "nms.cpu_nms", __pyx_m) < 0))) __PYX_ERR(0, 1, __pyx_L1_error)
+ }
+ }
+ #endif
+ /*--- Builtin init code ---*/
+ if (__Pyx_InitCachedBuiltins() < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ /*--- Constants init code ---*/
+ if (__Pyx_InitCachedConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ /*--- Global type/function init code ---*/
+ (void)__Pyx_modinit_global_init_code();
+ (void)__Pyx_modinit_variable_export_code();
+ (void)__Pyx_modinit_function_export_code();
+ (void)__Pyx_modinit_type_init_code();
+ if (unlikely((__Pyx_modinit_type_import_code() < 0))) __PYX_ERR(0, 1, __pyx_L1_error)
+ (void)__Pyx_modinit_variable_import_code();
+ (void)__Pyx_modinit_function_import_code();
+ /*--- Execution code ---*/
+ #if defined(__Pyx_Generator_USED) || defined(__Pyx_Coroutine_USED)
+ if (__Pyx_patch_abc() < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ #endif
+
+ /* "nms/cpu_nms.pyx":8
+ * # --------------------------------------------------------
+ *
+ * import numpy as np # <<<<<<<<<<<<<<
+ * cimport numpy as np
+ *
+ */
+ __pyx_t_2 = __Pyx_ImportDottedModule(__pyx_n_s_numpy, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 8, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_2);
+ if (PyDict_SetItem(__pyx_d, __pyx_n_s_np, __pyx_t_2) < 0) __PYX_ERR(0, 8, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+
+ /* "nms/cpu_nms.pyx":17
+ * return a if a <= b else b
+ *
+ * def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): # <<<<<<<<<<<<<<
+ * cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+ * cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ */
+ __pyx_t_2 = __Pyx_CyFunction_New(&__pyx_mdef_3nms_7cpu_nms_1cpu_nms, 0, __pyx_n_s_cpu_nms, NULL, __pyx_n_s_nms_cpu_nms, __pyx_d, ((PyObject *)__pyx_codeobj__12)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 17, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_2);
+ if (PyDict_SetItem(__pyx_d, __pyx_n_s_cpu_nms, __pyx_t_2) < 0) __PYX_ERR(0, 17, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+
+ /* "nms/cpu_nms.pyx":70
+ * return keep
+ *
+ * def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): # <<<<<<<<<<<<<<
+ * cdef unsigned int N = boxes.shape[0]
+ * cdef float iw, ih, box_area
+ */
+ __pyx_t_2 = PyFloat_FromDouble(((double)0.5)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 70, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_2);
+ __pyx_t_3 = PyFloat_FromDouble(((double)0.3)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 70, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_4 = PyFloat_FromDouble(((double)0.001)); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 70, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_4);
+ __pyx_t_5 = __Pyx_PyInt_From_unsigned_int(((unsigned int)0)); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 70, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_5);
+ __pyx_t_6 = PyTuple_New(4); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 70, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_6);
+ __Pyx_GIVEREF(__pyx_t_2);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 0, __pyx_t_2)) __PYX_ERR(0, 70, __pyx_L1_error);
+ __Pyx_GIVEREF(__pyx_t_3);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 1, __pyx_t_3)) __PYX_ERR(0, 70, __pyx_L1_error);
+ __Pyx_GIVEREF(__pyx_t_4);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 2, __pyx_t_4)) __PYX_ERR(0, 70, __pyx_L1_error);
+ __Pyx_GIVEREF(__pyx_t_5);
+ if (__Pyx_PyTuple_SET_ITEM(__pyx_t_6, 3, __pyx_t_5)) __PYX_ERR(0, 70, __pyx_L1_error);
+ __pyx_t_2 = 0;
+ __pyx_t_3 = 0;
+ __pyx_t_4 = 0;
+ __pyx_t_5 = 0;
+ __pyx_t_5 = __Pyx_CyFunction_New(&__pyx_mdef_3nms_7cpu_nms_3cpu_soft_nms, 0, __pyx_n_s_cpu_soft_nms, NULL, __pyx_n_s_nms_cpu_nms, __pyx_d, ((PyObject *)__pyx_codeobj__14)); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 70, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_5);
+ __Pyx_CyFunction_SetDefaultsTuple(__pyx_t_5, __pyx_t_6);
+ __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
+ if (PyDict_SetItem(__pyx_d, __pyx_n_s_cpu_soft_nms, __pyx_t_5) < 0) __PYX_ERR(0, 70, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
+
+ /* "nms/cpu_nms.pyx":1
+ * # -------------------------------------------------------- # <<<<<<<<<<<<<<
+ * # Fast R-CNN
+ * # Copyright (c) 2015 Microsoft
+ */
+ __pyx_t_5 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 1, __pyx_L1_error)
+ __Pyx_GOTREF(__pyx_t_5);
+ if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_5) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
+ __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
+
+ /*--- Wrapped vars code ---*/
+
+ goto __pyx_L0;
+ __pyx_L1_error:;
+ __Pyx_XDECREF(__pyx_t_2);
+ __Pyx_XDECREF(__pyx_t_3);
+ __Pyx_XDECREF(__pyx_t_4);
+ __Pyx_XDECREF(__pyx_t_5);
+ __Pyx_XDECREF(__pyx_t_6);
+ if (__pyx_m) {
+ if (__pyx_d && stringtab_initialized) {
+ __Pyx_AddTraceback("init nms.cpu_nms", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ }
+ #if !CYTHON_USE_MODULE_STATE
+ Py_CLEAR(__pyx_m);
+ #else
+ Py_DECREF(__pyx_m);
+ if (pystate_addmodule_run) {
+ PyObject *tp, *value, *tb;
+ PyErr_Fetch(&tp, &value, &tb);
+ PyState_RemoveModule(&__pyx_moduledef);
+ PyErr_Restore(tp, value, tb);
+ }
+ #endif
+ } else if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_ImportError, "init nms.cpu_nms");
+ }
+ __pyx_L0:;
+ __Pyx_RefNannyFinishContext();
+ #if CYTHON_PEP489_MULTI_PHASE_INIT
+ return (__pyx_m != NULL) ? 0 : -1;
+ #elif PY_MAJOR_VERSION >= 3
+ return __pyx_m;
+ #else
+ return;
+ #endif
+}
+/* #### Code section: cleanup_globals ### */
+/* #### Code section: cleanup_module ### */
+/* #### Code section: main_method ### */
+/* #### Code section: utility_code_pragmas ### */
+#ifdef _MSC_VER
+#pragma warning( push )
+/* Warning 4127: conditional expression is constant
+ * Cython uses constant conditional expressions to allow in inline functions to be optimized at
+ * compile-time, so this warning is not useful
+ */
+#pragma warning( disable : 4127 )
+#endif
+
+
+
+/* #### Code section: utility_code_def ### */
+
+/* --- Runtime support code --- */
+/* Refnanny */
+#if CYTHON_REFNANNY
+static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname) {
+ PyObject *m = NULL, *p = NULL;
+ void *r = NULL;
+ m = PyImport_ImportModule(modname);
+ if (!m) goto end;
+ p = PyObject_GetAttrString(m, "RefNannyAPI");
+ if (!p) goto end;
+ r = PyLong_AsVoidPtr(p);
+end:
+ Py_XDECREF(p);
+ Py_XDECREF(m);
+ return (__Pyx_RefNannyAPIStruct *)r;
+}
+#endif
+
+/* PyErrExceptionMatches */
+#if CYTHON_FAST_THREAD_STATE
+static int __Pyx_PyErr_ExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) {
+ Py_ssize_t i, n;
+ n = PyTuple_GET_SIZE(tuple);
+#if PY_MAJOR_VERSION >= 3
+ for (i=0; i= 0x030C00A6
+ PyObject *current_exception = tstate->current_exception;
+ if (unlikely(!current_exception)) return 0;
+ exc_type = (PyObject*) Py_TYPE(current_exception);
+ if (exc_type == err) return 1;
+#else
+ exc_type = tstate->curexc_type;
+ if (exc_type == err) return 1;
+ if (unlikely(!exc_type)) return 0;
+#endif
+ #if CYTHON_AVOID_BORROWED_REFS
+ Py_INCREF(exc_type);
+ #endif
+ if (unlikely(PyTuple_Check(err))) {
+ result = __Pyx_PyErr_ExceptionMatchesTuple(exc_type, err);
+ } else {
+ result = __Pyx_PyErr_GivenExceptionMatches(exc_type, err);
+ }
+ #if CYTHON_AVOID_BORROWED_REFS
+ Py_DECREF(exc_type);
+ #endif
+ return result;
+}
+#endif
+
+/* PyErrFetchRestore */
+#if CYTHON_FAST_THREAD_STATE
+static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) {
+#if PY_VERSION_HEX >= 0x030C00A6
+ PyObject *tmp_value;
+ assert(type == NULL || (value != NULL && type == (PyObject*) Py_TYPE(value)));
+ if (value) {
+ #if CYTHON_COMPILING_IN_CPYTHON
+ if (unlikely(((PyBaseExceptionObject*) value)->traceback != tb))
+ #endif
+ PyException_SetTraceback(value, tb);
+ }
+ tmp_value = tstate->current_exception;
+ tstate->current_exception = value;
+ Py_XDECREF(tmp_value);
+ Py_XDECREF(type);
+ Py_XDECREF(tb);
+#else
+ PyObject *tmp_type, *tmp_value, *tmp_tb;
+ tmp_type = tstate->curexc_type;
+ tmp_value = tstate->curexc_value;
+ tmp_tb = tstate->curexc_traceback;
+ tstate->curexc_type = type;
+ tstate->curexc_value = value;
+ tstate->curexc_traceback = tb;
+ Py_XDECREF(tmp_type);
+ Py_XDECREF(tmp_value);
+ Py_XDECREF(tmp_tb);
+#endif
+}
+static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) {
+#if PY_VERSION_HEX >= 0x030C00A6
+ PyObject* exc_value;
+ exc_value = tstate->current_exception;
+ tstate->current_exception = 0;
+ *value = exc_value;
+ *type = NULL;
+ *tb = NULL;
+ if (exc_value) {
+ *type = (PyObject*) Py_TYPE(exc_value);
+ Py_INCREF(*type);
+ #if CYTHON_COMPILING_IN_CPYTHON
+ *tb = ((PyBaseExceptionObject*) exc_value)->traceback;
+ Py_XINCREF(*tb);
+ #else
+ *tb = PyException_GetTraceback(exc_value);
+ #endif
+ }
+#else
+ *type = tstate->curexc_type;
+ *value = tstate->curexc_value;
+ *tb = tstate->curexc_traceback;
+ tstate->curexc_type = 0;
+ tstate->curexc_value = 0;
+ tstate->curexc_traceback = 0;
+#endif
+}
+#endif
+
+/* PyObjectGetAttrStr */
+#if CYTHON_USE_TYPE_SLOTS
+static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name) {
+ PyTypeObject* tp = Py_TYPE(obj);
+ if (likely(tp->tp_getattro))
+ return tp->tp_getattro(obj, attr_name);
+#if PY_MAJOR_VERSION < 3
+ if (likely(tp->tp_getattr))
+ return tp->tp_getattr(obj, PyString_AS_STRING(attr_name));
+#endif
+ return PyObject_GetAttr(obj, attr_name);
+}
+#endif
+
+/* PyObjectGetAttrStrNoError */
+#if __PYX_LIMITED_VERSION_HEX < 0x030d00A1
+static void __Pyx_PyObject_GetAttrStr_ClearAttributeError(void) {
+ __Pyx_PyThreadState_declare
+ __Pyx_PyThreadState_assign
+ if (likely(__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError)))
+ __Pyx_PyErr_Clear();
+}
+#endif
+static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name) {
+ PyObject *result;
+#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1
+ (void) PyObject_GetOptionalAttr(obj, attr_name, &result);
+ return result;
+#else
+#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_USE_TYPE_SLOTS && PY_VERSION_HEX >= 0x030700B1
+ PyTypeObject* tp = Py_TYPE(obj);
+ if (likely(tp->tp_getattro == PyObject_GenericGetAttr)) {
+ return _PyObject_GenericGetAttrWithDict(obj, attr_name, NULL, 1);
+ }
+#endif
+ result = __Pyx_PyObject_GetAttrStr(obj, attr_name);
+ if (unlikely(!result)) {
+ __Pyx_PyObject_GetAttrStr_ClearAttributeError();
+ }
+ return result;
+#endif
+}
+
+/* GetBuiltinName */
+static PyObject *__Pyx_GetBuiltinName(PyObject *name) {
+ PyObject* result = __Pyx_PyObject_GetAttrStrNoError(__pyx_b, name);
+ if (unlikely(!result) && !PyErr_Occurred()) {
+ PyErr_Format(PyExc_NameError,
+#if PY_MAJOR_VERSION >= 3
+ "name '%U' is not defined", name);
+#else
+ "name '%.200s' is not defined", PyString_AS_STRING(name));
+#endif
+ }
+ return result;
+}
+
+/* GetTopmostException */
+#if CYTHON_USE_EXC_INFO_STACK && CYTHON_FAST_THREAD_STATE
+static _PyErr_StackItem *
+__Pyx_PyErr_GetTopmostException(PyThreadState *tstate)
+{
+ _PyErr_StackItem *exc_info = tstate->exc_info;
+ while ((exc_info->exc_value == NULL || exc_info->exc_value == Py_None) &&
+ exc_info->previous_item != NULL)
+ {
+ exc_info = exc_info->previous_item;
+ }
+ return exc_info;
+}
+#endif
+
+/* SaveResetException */
+#if CYTHON_FAST_THREAD_STATE
+static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) {
+ #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4
+ _PyErr_StackItem *exc_info = __Pyx_PyErr_GetTopmostException(tstate);
+ PyObject *exc_value = exc_info->exc_value;
+ if (exc_value == NULL || exc_value == Py_None) {
+ *value = NULL;
+ *type = NULL;
+ *tb = NULL;
+ } else {
+ *value = exc_value;
+ Py_INCREF(*value);
+ *type = (PyObject*) Py_TYPE(exc_value);
+ Py_INCREF(*type);
+ *tb = PyException_GetTraceback(exc_value);
+ }
+ #elif CYTHON_USE_EXC_INFO_STACK
+ _PyErr_StackItem *exc_info = __Pyx_PyErr_GetTopmostException(tstate);
+ *type = exc_info->exc_type;
+ *value = exc_info->exc_value;
+ *tb = exc_info->exc_traceback;
+ Py_XINCREF(*type);
+ Py_XINCREF(*value);
+ Py_XINCREF(*tb);
+ #else
+ *type = tstate->exc_type;
+ *value = tstate->exc_value;
+ *tb = tstate->exc_traceback;
+ Py_XINCREF(*type);
+ Py_XINCREF(*value);
+ Py_XINCREF(*tb);
+ #endif
+}
+static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) {
+ #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4
+ _PyErr_StackItem *exc_info = tstate->exc_info;
+ PyObject *tmp_value = exc_info->exc_value;
+ exc_info->exc_value = value;
+ Py_XDECREF(tmp_value);
+ Py_XDECREF(type);
+ Py_XDECREF(tb);
+ #else
+ PyObject *tmp_type, *tmp_value, *tmp_tb;
+ #if CYTHON_USE_EXC_INFO_STACK
+ _PyErr_StackItem *exc_info = tstate->exc_info;
+ tmp_type = exc_info->exc_type;
+ tmp_value = exc_info->exc_value;
+ tmp_tb = exc_info->exc_traceback;
+ exc_info->exc_type = type;
+ exc_info->exc_value = value;
+ exc_info->exc_traceback = tb;
+ #else
+ tmp_type = tstate->exc_type;
+ tmp_value = tstate->exc_value;
+ tmp_tb = tstate->exc_traceback;
+ tstate->exc_type = type;
+ tstate->exc_value = value;
+ tstate->exc_traceback = tb;
+ #endif
+ Py_XDECREF(tmp_type);
+ Py_XDECREF(tmp_value);
+ Py_XDECREF(tmp_tb);
+ #endif
+}
+#endif
+
+/* GetException */
+#if CYTHON_FAST_THREAD_STATE
+static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb)
+#else
+static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb)
+#endif
+{
+ PyObject *local_type = NULL, *local_value, *local_tb = NULL;
+#if CYTHON_FAST_THREAD_STATE
+ PyObject *tmp_type, *tmp_value, *tmp_tb;
+ #if PY_VERSION_HEX >= 0x030C00A6
+ local_value = tstate->current_exception;
+ tstate->current_exception = 0;
+ if (likely(local_value)) {
+ local_type = (PyObject*) Py_TYPE(local_value);
+ Py_INCREF(local_type);
+ local_tb = PyException_GetTraceback(local_value);
+ }
+ #else
+ local_type = tstate->curexc_type;
+ local_value = tstate->curexc_value;
+ local_tb = tstate->curexc_traceback;
+ tstate->curexc_type = 0;
+ tstate->curexc_value = 0;
+ tstate->curexc_traceback = 0;
+ #endif
+#else
+ PyErr_Fetch(&local_type, &local_value, &local_tb);
+#endif
+ PyErr_NormalizeException(&local_type, &local_value, &local_tb);
+#if CYTHON_FAST_THREAD_STATE && PY_VERSION_HEX >= 0x030C00A6
+ if (unlikely(tstate->current_exception))
+#elif CYTHON_FAST_THREAD_STATE
+ if (unlikely(tstate->curexc_type))
+#else
+ if (unlikely(PyErr_Occurred()))
+#endif
+ goto bad;
+ #if PY_MAJOR_VERSION >= 3
+ if (local_tb) {
+ if (unlikely(PyException_SetTraceback(local_value, local_tb) < 0))
+ goto bad;
+ }
+ #endif
+ Py_XINCREF(local_tb);
+ Py_XINCREF(local_type);
+ Py_XINCREF(local_value);
+ *type = local_type;
+ *value = local_value;
+ *tb = local_tb;
+#if CYTHON_FAST_THREAD_STATE
+ #if CYTHON_USE_EXC_INFO_STACK
+ {
+ _PyErr_StackItem *exc_info = tstate->exc_info;
+ #if PY_VERSION_HEX >= 0x030B00a4
+ tmp_value = exc_info->exc_value;
+ exc_info->exc_value = local_value;
+ tmp_type = NULL;
+ tmp_tb = NULL;
+ Py_XDECREF(local_type);
+ Py_XDECREF(local_tb);
+ #else
+ tmp_type = exc_info->exc_type;
+ tmp_value = exc_info->exc_value;
+ tmp_tb = exc_info->exc_traceback;
+ exc_info->exc_type = local_type;
+ exc_info->exc_value = local_value;
+ exc_info->exc_traceback = local_tb;
+ #endif
+ }
+ #else
+ tmp_type = tstate->exc_type;
+ tmp_value = tstate->exc_value;
+ tmp_tb = tstate->exc_traceback;
+ tstate->exc_type = local_type;
+ tstate->exc_value = local_value;
+ tstate->exc_traceback = local_tb;
+ #endif
+ Py_XDECREF(tmp_type);
+ Py_XDECREF(tmp_value);
+ Py_XDECREF(tmp_tb);
+#else
+ PyErr_SetExcInfo(local_type, local_value, local_tb);
+#endif
+ return 0;
+bad:
+ *type = 0;
+ *value = 0;
+ *tb = 0;
+ Py_XDECREF(local_type);
+ Py_XDECREF(local_value);
+ Py_XDECREF(local_tb);
+ return -1;
+}
+
+/* PyObjectCall */
+#if CYTHON_COMPILING_IN_CPYTHON
+static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw) {
+ PyObject *result;
+ ternaryfunc call = Py_TYPE(func)->tp_call;
+ if (unlikely(!call))
+ return PyObject_Call(func, arg, kw);
+ #if PY_MAJOR_VERSION < 3
+ if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object")))
+ return NULL;
+ #else
+ if (unlikely(Py_EnterRecursiveCall(" while calling a Python object")))
+ return NULL;
+ #endif
+ result = (*call)(func, arg, kw);
+ Py_LeaveRecursiveCall();
+ if (unlikely(!result) && unlikely(!PyErr_Occurred())) {
+ PyErr_SetString(
+ PyExc_SystemError,
+ "NULL result without error in PyObject_Call");
+ }
+ return result;
+}
+#endif
+
+/* RaiseException */
+#if PY_MAJOR_VERSION < 3
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
+ __Pyx_PyThreadState_declare
+ CYTHON_UNUSED_VAR(cause);
+ Py_XINCREF(type);
+ if (!value || value == Py_None)
+ value = NULL;
+ else
+ Py_INCREF(value);
+ if (!tb || tb == Py_None)
+ tb = NULL;
+ else {
+ Py_INCREF(tb);
+ if (!PyTraceBack_Check(tb)) {
+ PyErr_SetString(PyExc_TypeError,
+ "raise: arg 3 must be a traceback or None");
+ goto raise_error;
+ }
+ }
+ if (PyType_Check(type)) {
+#if CYTHON_COMPILING_IN_PYPY
+ if (!value) {
+ Py_INCREF(Py_None);
+ value = Py_None;
+ }
+#endif
+ PyErr_NormalizeException(&type, &value, &tb);
+ } else {
+ if (value) {
+ PyErr_SetString(PyExc_TypeError,
+ "instance exception may not have a separate value");
+ goto raise_error;
+ }
+ value = type;
+ type = (PyObject*) Py_TYPE(type);
+ Py_INCREF(type);
+ if (!PyType_IsSubtype((PyTypeObject *)type, (PyTypeObject *)PyExc_BaseException)) {
+ PyErr_SetString(PyExc_TypeError,
+ "raise: exception class must be a subclass of BaseException");
+ goto raise_error;
+ }
+ }
+ __Pyx_PyThreadState_assign
+ __Pyx_ErrRestore(type, value, tb);
+ return;
+raise_error:
+ Py_XDECREF(value);
+ Py_XDECREF(type);
+ Py_XDECREF(tb);
+ return;
+}
+#else
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
+ PyObject* owned_instance = NULL;
+ if (tb == Py_None) {
+ tb = 0;
+ } else if (tb && !PyTraceBack_Check(tb)) {
+ PyErr_SetString(PyExc_TypeError,
+ "raise: arg 3 must be a traceback or None");
+ goto bad;
+ }
+ if (value == Py_None)
+ value = 0;
+ if (PyExceptionInstance_Check(type)) {
+ if (value) {
+ PyErr_SetString(PyExc_TypeError,
+ "instance exception may not have a separate value");
+ goto bad;
+ }
+ value = type;
+ type = (PyObject*) Py_TYPE(value);
+ } else if (PyExceptionClass_Check(type)) {
+ PyObject *instance_class = NULL;
+ if (value && PyExceptionInstance_Check(value)) {
+ instance_class = (PyObject*) Py_TYPE(value);
+ if (instance_class != type) {
+ int is_subclass = PyObject_IsSubclass(instance_class, type);
+ if (!is_subclass) {
+ instance_class = NULL;
+ } else if (unlikely(is_subclass == -1)) {
+ goto bad;
+ } else {
+ type = instance_class;
+ }
+ }
+ }
+ if (!instance_class) {
+ PyObject *args;
+ if (!value)
+ args = PyTuple_New(0);
+ else if (PyTuple_Check(value)) {
+ Py_INCREF(value);
+ args = value;
+ } else
+ args = PyTuple_Pack(1, value);
+ if (!args)
+ goto bad;
+ owned_instance = PyObject_Call(type, args, NULL);
+ Py_DECREF(args);
+ if (!owned_instance)
+ goto bad;
+ value = owned_instance;
+ if (!PyExceptionInstance_Check(value)) {
+ PyErr_Format(PyExc_TypeError,
+ "calling %R should have returned an instance of "
+ "BaseException, not %R",
+ type, Py_TYPE(value));
+ goto bad;
+ }
+ }
+ } else {
+ PyErr_SetString(PyExc_TypeError,
+ "raise: exception class must be a subclass of BaseException");
+ goto bad;
+ }
+ if (cause) {
+ PyObject *fixed_cause;
+ if (cause == Py_None) {
+ fixed_cause = NULL;
+ } else if (PyExceptionClass_Check(cause)) {
+ fixed_cause = PyObject_CallObject(cause, NULL);
+ if (fixed_cause == NULL)
+ goto bad;
+ } else if (PyExceptionInstance_Check(cause)) {
+ fixed_cause = cause;
+ Py_INCREF(fixed_cause);
+ } else {
+ PyErr_SetString(PyExc_TypeError,
+ "exception causes must derive from "
+ "BaseException");
+ goto bad;
+ }
+ PyException_SetCause(value, fixed_cause);
+ }
+ PyErr_SetObject(type, value);
+ if (tb) {
+ #if PY_VERSION_HEX >= 0x030C00A6
+ PyException_SetTraceback(value, tb);
+ #elif CYTHON_FAST_THREAD_STATE
+ PyThreadState *tstate = __Pyx_PyThreadState_Current;
+ PyObject* tmp_tb = tstate->curexc_traceback;
+ if (tb != tmp_tb) {
+ Py_INCREF(tb);
+ tstate->curexc_traceback = tb;
+ Py_XDECREF(tmp_tb);
+ }
+#else
+ PyObject *tmp_type, *tmp_value, *tmp_tb;
+ PyErr_Fetch(&tmp_type, &tmp_value, &tmp_tb);
+ Py_INCREF(tb);
+ PyErr_Restore(tmp_type, tmp_value, tb);
+ Py_XDECREF(tmp_tb);
+#endif
+ }
+bad:
+ Py_XDECREF(owned_instance);
+ return;
+}
+#endif
+
+/* TupleAndListFromArray */
+#if CYTHON_COMPILING_IN_CPYTHON
+static CYTHON_INLINE void __Pyx_copy_object_array(PyObject *const *CYTHON_RESTRICT src, PyObject** CYTHON_RESTRICT dest, Py_ssize_t length) {
+ PyObject *v;
+ Py_ssize_t i;
+ for (i = 0; i < length; i++) {
+ v = dest[i] = src[i];
+ Py_INCREF(v);
+ }
+}
+static CYTHON_INLINE PyObject *
+__Pyx_PyTuple_FromArray(PyObject *const *src, Py_ssize_t n)
+{
+ PyObject *res;
+ if (n <= 0) {
+ Py_INCREF(__pyx_empty_tuple);
+ return __pyx_empty_tuple;
+ }
+ res = PyTuple_New(n);
+ if (unlikely(res == NULL)) return NULL;
+ __Pyx_copy_object_array(src, ((PyTupleObject*)res)->ob_item, n);
+ return res;
+}
+static CYTHON_INLINE PyObject *
+__Pyx_PyList_FromArray(PyObject *const *src, Py_ssize_t n)
+{
+ PyObject *res;
+ if (n <= 0) {
+ return PyList_New(0);
+ }
+ res = PyList_New(n);
+ if (unlikely(res == NULL)) return NULL;
+ __Pyx_copy_object_array(src, ((PyListObject*)res)->ob_item, n);
+ return res;
+}
+#endif
+
+/* BytesEquals */
+static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals) {
+#if CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API
+ return PyObject_RichCompareBool(s1, s2, equals);
+#else
+ if (s1 == s2) {
+ return (equals == Py_EQ);
+ } else if (PyBytes_CheckExact(s1) & PyBytes_CheckExact(s2)) {
+ const char *ps1, *ps2;
+ Py_ssize_t length = PyBytes_GET_SIZE(s1);
+ if (length != PyBytes_GET_SIZE(s2))
+ return (equals == Py_NE);
+ ps1 = PyBytes_AS_STRING(s1);
+ ps2 = PyBytes_AS_STRING(s2);
+ if (ps1[0] != ps2[0]) {
+ return (equals == Py_NE);
+ } else if (length == 1) {
+ return (equals == Py_EQ);
+ } else {
+ int result;
+#if CYTHON_USE_UNICODE_INTERNALS && (PY_VERSION_HEX < 0x030B0000)
+ Py_hash_t hash1, hash2;
+ hash1 = ((PyBytesObject*)s1)->ob_shash;
+ hash2 = ((PyBytesObject*)s2)->ob_shash;
+ if (hash1 != hash2 && hash1 != -1 && hash2 != -1) {
+ return (equals == Py_NE);
+ }
+#endif
+ result = memcmp(ps1, ps2, (size_t)length);
+ return (equals == Py_EQ) ? (result == 0) : (result != 0);
+ }
+ } else if ((s1 == Py_None) & PyBytes_CheckExact(s2)) {
+ return (equals == Py_NE);
+ } else if ((s2 == Py_None) & PyBytes_CheckExact(s1)) {
+ return (equals == Py_NE);
+ } else {
+ int result;
+ PyObject* py_result = PyObject_RichCompare(s1, s2, equals);
+ if (!py_result)
+ return -1;
+ result = __Pyx_PyObject_IsTrue(py_result);
+ Py_DECREF(py_result);
+ return result;
+ }
+#endif
+}
+
+/* UnicodeEquals */
+static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals) {
+#if CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API
+ return PyObject_RichCompareBool(s1, s2, equals);
+#else
+#if PY_MAJOR_VERSION < 3
+ PyObject* owned_ref = NULL;
+#endif
+ int s1_is_unicode, s2_is_unicode;
+ if (s1 == s2) {
+ goto return_eq;
+ }
+ s1_is_unicode = PyUnicode_CheckExact(s1);
+ s2_is_unicode = PyUnicode_CheckExact(s2);
+#if PY_MAJOR_VERSION < 3
+ if ((s1_is_unicode & (!s2_is_unicode)) && PyString_CheckExact(s2)) {
+ owned_ref = PyUnicode_FromObject(s2);
+ if (unlikely(!owned_ref))
+ return -1;
+ s2 = owned_ref;
+ s2_is_unicode = 1;
+ } else if ((s2_is_unicode & (!s1_is_unicode)) && PyString_CheckExact(s1)) {
+ owned_ref = PyUnicode_FromObject(s1);
+ if (unlikely(!owned_ref))
+ return -1;
+ s1 = owned_ref;
+ s1_is_unicode = 1;
+ } else if (((!s2_is_unicode) & (!s1_is_unicode))) {
+ return __Pyx_PyBytes_Equals(s1, s2, equals);
+ }
+#endif
+ if (s1_is_unicode & s2_is_unicode) {
+ Py_ssize_t length;
+ int kind;
+ void *data1, *data2;
+ if (unlikely(__Pyx_PyUnicode_READY(s1) < 0) || unlikely(__Pyx_PyUnicode_READY(s2) < 0))
+ return -1;
+ length = __Pyx_PyUnicode_GET_LENGTH(s1);
+ if (length != __Pyx_PyUnicode_GET_LENGTH(s2)) {
+ goto return_ne;
+ }
+#if CYTHON_USE_UNICODE_INTERNALS
+ {
+ Py_hash_t hash1, hash2;
+ #if CYTHON_PEP393_ENABLED
+ hash1 = ((PyASCIIObject*)s1)->hash;
+ hash2 = ((PyASCIIObject*)s2)->hash;
+ #else
+ hash1 = ((PyUnicodeObject*)s1)->hash;
+ hash2 = ((PyUnicodeObject*)s2)->hash;
+ #endif
+ if (hash1 != hash2 && hash1 != -1 && hash2 != -1) {
+ goto return_ne;
+ }
+ }
+#endif
+ kind = __Pyx_PyUnicode_KIND(s1);
+ if (kind != __Pyx_PyUnicode_KIND(s2)) {
+ goto return_ne;
+ }
+ data1 = __Pyx_PyUnicode_DATA(s1);
+ data2 = __Pyx_PyUnicode_DATA(s2);
+ if (__Pyx_PyUnicode_READ(kind, data1, 0) != __Pyx_PyUnicode_READ(kind, data2, 0)) {
+ goto return_ne;
+ } else if (length == 1) {
+ goto return_eq;
+ } else {
+ int result = memcmp(data1, data2, (size_t)(length * kind));
+ #if PY_MAJOR_VERSION < 3
+ Py_XDECREF(owned_ref);
+ #endif
+ return (equals == Py_EQ) ? (result == 0) : (result != 0);
+ }
+ } else if ((s1 == Py_None) & s2_is_unicode) {
+ goto return_ne;
+ } else if ((s2 == Py_None) & s1_is_unicode) {
+ goto return_ne;
+ } else {
+ int result;
+ PyObject* py_result = PyObject_RichCompare(s1, s2, equals);
+ #if PY_MAJOR_VERSION < 3
+ Py_XDECREF(owned_ref);
+ #endif
+ if (!py_result)
+ return -1;
+ result = __Pyx_PyObject_IsTrue(py_result);
+ Py_DECREF(py_result);
+ return result;
+ }
+return_eq:
+ #if PY_MAJOR_VERSION < 3
+ Py_XDECREF(owned_ref);
+ #endif
+ return (equals == Py_EQ);
+return_ne:
+ #if PY_MAJOR_VERSION < 3
+ Py_XDECREF(owned_ref);
+ #endif
+ return (equals == Py_NE);
+#endif
+}
+
+/* fastcall */
+#if CYTHON_METH_FASTCALL
+static CYTHON_INLINE PyObject * __Pyx_GetKwValue_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues, PyObject *s)
+{
+ Py_ssize_t i, n = PyTuple_GET_SIZE(kwnames);
+ for (i = 0; i < n; i++)
+ {
+ if (s == PyTuple_GET_ITEM(kwnames, i)) return kwvalues[i];
+ }
+ for (i = 0; i < n; i++)
+ {
+ int eq = __Pyx_PyUnicode_Equals(s, PyTuple_GET_ITEM(kwnames, i), Py_EQ);
+ if (unlikely(eq != 0)) {
+ if (unlikely(eq < 0)) return NULL;
+ return kwvalues[i];
+ }
+ }
+ return NULL;
+}
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000
+CYTHON_UNUSED static PyObject *__Pyx_KwargsAsDict_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues) {
+ Py_ssize_t i, nkwargs = PyTuple_GET_SIZE(kwnames);
+ PyObject *dict;
+ dict = PyDict_New();
+ if (unlikely(!dict))
+ return NULL;
+ for (i=0; i= 3
+ "%s() got multiple values for keyword argument '%U'", func_name, kw_name);
+ #else
+ "%s() got multiple values for keyword argument '%s'", func_name,
+ PyString_AsString(kw_name));
+ #endif
+}
+
+/* ParseKeywords */
+static int __Pyx_ParseOptionalKeywords(
+ PyObject *kwds,
+ PyObject *const *kwvalues,
+ PyObject **argnames[],
+ PyObject *kwds2,
+ PyObject *values[],
+ Py_ssize_t num_pos_args,
+ const char* function_name)
+{
+ PyObject *key = 0, *value = 0;
+ Py_ssize_t pos = 0;
+ PyObject*** name;
+ PyObject*** first_kw_arg = argnames + num_pos_args;
+ int kwds_is_tuple = CYTHON_METH_FASTCALL && likely(PyTuple_Check(kwds));
+ while (1) {
+ Py_XDECREF(key); key = NULL;
+ Py_XDECREF(value); value = NULL;
+ if (kwds_is_tuple) {
+ Py_ssize_t size;
+#if CYTHON_ASSUME_SAFE_MACROS
+ size = PyTuple_GET_SIZE(kwds);
+#else
+ size = PyTuple_Size(kwds);
+ if (size < 0) goto bad;
+#endif
+ if (pos >= size) break;
+#if CYTHON_AVOID_BORROWED_REFS
+ key = __Pyx_PySequence_ITEM(kwds, pos);
+ if (!key) goto bad;
+#elif CYTHON_ASSUME_SAFE_MACROS
+ key = PyTuple_GET_ITEM(kwds, pos);
+#else
+ key = PyTuple_GetItem(kwds, pos);
+ if (!key) goto bad;
+#endif
+ value = kwvalues[pos];
+ pos++;
+ }
+ else
+ {
+ if (!PyDict_Next(kwds, &pos, &key, &value)) break;
+#if CYTHON_AVOID_BORROWED_REFS
+ Py_INCREF(key);
+#endif
+ }
+ name = first_kw_arg;
+ while (*name && (**name != key)) name++;
+ if (*name) {
+ values[name-argnames] = value;
+#if CYTHON_AVOID_BORROWED_REFS
+ Py_INCREF(value);
+ Py_DECREF(key);
+#endif
+ key = NULL;
+ value = NULL;
+ continue;
+ }
+#if !CYTHON_AVOID_BORROWED_REFS
+ Py_INCREF(key);
+#endif
+ Py_INCREF(value);
+ name = first_kw_arg;
+ #if PY_MAJOR_VERSION < 3
+ if (likely(PyString_Check(key))) {
+ while (*name) {
+ if ((CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**name) == PyString_GET_SIZE(key))
+ && _PyString_Eq(**name, key)) {
+ values[name-argnames] = value;
+#if CYTHON_AVOID_BORROWED_REFS
+ value = NULL;
+#endif
+ break;
+ }
+ name++;
+ }
+ if (*name) continue;
+ else {
+ PyObject*** argname = argnames;
+ while (argname != first_kw_arg) {
+ if ((**argname == key) || (
+ (CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**argname) == PyString_GET_SIZE(key))
+ && _PyString_Eq(**argname, key))) {
+ goto arg_passed_twice;
+ }
+ argname++;
+ }
+ }
+ } else
+ #endif
+ if (likely(PyUnicode_Check(key))) {
+ while (*name) {
+ int cmp = (
+ #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3
+ (__Pyx_PyUnicode_GET_LENGTH(**name) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 :
+ #endif
+ PyUnicode_Compare(**name, key)
+ );
+ if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad;
+ if (cmp == 0) {
+ values[name-argnames] = value;
+#if CYTHON_AVOID_BORROWED_REFS
+ value = NULL;
+#endif
+ break;
+ }
+ name++;
+ }
+ if (*name) continue;
+ else {
+ PyObject*** argname = argnames;
+ while (argname != first_kw_arg) {
+ int cmp = (**argname == key) ? 0 :
+ #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3
+ (__Pyx_PyUnicode_GET_LENGTH(**argname) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 :
+ #endif
+ PyUnicode_Compare(**argname, key);
+ if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad;
+ if (cmp == 0) goto arg_passed_twice;
+ argname++;
+ }
+ }
+ } else
+ goto invalid_keyword_type;
+ if (kwds2) {
+ if (unlikely(PyDict_SetItem(kwds2, key, value))) goto bad;
+ } else {
+ goto invalid_keyword;
+ }
+ }
+ Py_XDECREF(key);
+ Py_XDECREF(value);
+ return 0;
+arg_passed_twice:
+ __Pyx_RaiseDoubleKeywordsError(function_name, key);
+ goto bad;
+invalid_keyword_type:
+ PyErr_Format(PyExc_TypeError,
+ "%.200s() keywords must be strings", function_name);
+ goto bad;
+invalid_keyword:
+ #if PY_MAJOR_VERSION < 3
+ PyErr_Format(PyExc_TypeError,
+ "%.200s() got an unexpected keyword argument '%.200s'",
+ function_name, PyString_AsString(key));
+ #else
+ PyErr_Format(PyExc_TypeError,
+ "%s() got an unexpected keyword argument '%U'",
+ function_name, key);
+ #endif
+bad:
+ Py_XDECREF(key);
+ Py_XDECREF(value);
+ return -1;
+}
+
+/* ArgTypeTest */
+static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact)
+{
+ __Pyx_TypeName type_name;
+ __Pyx_TypeName obj_type_name;
+ if (unlikely(!type)) {
+ PyErr_SetString(PyExc_SystemError, "Missing type object");
+ return 0;
+ }
+ else if (exact) {
+ #if PY_MAJOR_VERSION == 2
+ if ((type == &PyBaseString_Type) && likely(__Pyx_PyBaseString_CheckExact(obj))) return 1;
+ #endif
+ }
+ else {
+ if (likely(__Pyx_TypeCheck(obj, type))) return 1;
+ }
+ type_name = __Pyx_PyType_GetName(type);
+ obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj));
+ PyErr_Format(PyExc_TypeError,
+ "Argument '%.200s' has incorrect type (expected " __Pyx_FMT_TYPENAME
+ ", got " __Pyx_FMT_TYPENAME ")", name, type_name, obj_type_name);
+ __Pyx_DECREF_TypeName(type_name);
+ __Pyx_DECREF_TypeName(obj_type_name);
+ return 0;
+}
+
+/* IsLittleEndian */
+static CYTHON_INLINE int __Pyx_Is_Little_Endian(void)
+{
+ union {
+ uint32_t u32;
+ uint8_t u8[4];
+ } S;
+ S.u32 = 0x01020304;
+ return S.u8[0] == 4;
+}
+
+/* BufferFormatCheck */
+static void __Pyx_BufFmt_Init(__Pyx_BufFmt_Context* ctx,
+ __Pyx_BufFmt_StackElem* stack,
+ __Pyx_TypeInfo* type) {
+ stack[0].field = &ctx->root;
+ stack[0].parent_offset = 0;
+ ctx->root.type = type;
+ ctx->root.name = "buffer dtype";
+ ctx->root.offset = 0;
+ ctx->head = stack;
+ ctx->head->field = &ctx->root;
+ ctx->fmt_offset = 0;
+ ctx->head->parent_offset = 0;
+ ctx->new_packmode = '@';
+ ctx->enc_packmode = '@';
+ ctx->new_count = 1;
+ ctx->enc_count = 0;
+ ctx->enc_type = 0;
+ ctx->is_complex = 0;
+ ctx->is_valid_array = 0;
+ ctx->struct_alignment = 0;
+ while (type->typegroup == 'S') {
+ ++ctx->head;
+ ctx->head->field = type->fields;
+ ctx->head->parent_offset = 0;
+ type = type->fields->type;
+ }
+}
+static int __Pyx_BufFmt_ParseNumber(const char** ts) {
+ int count;
+ const char* t = *ts;
+ if (*t < '0' || *t > '9') {
+ return -1;
+ } else {
+ count = *t++ - '0';
+ while (*t >= '0' && *t <= '9') {
+ count *= 10;
+ count += *t++ - '0';
+ }
+ }
+ *ts = t;
+ return count;
+}
+static int __Pyx_BufFmt_ExpectNumber(const char **ts) {
+ int number = __Pyx_BufFmt_ParseNumber(ts);
+ if (number == -1)
+ PyErr_Format(PyExc_ValueError,\
+ "Does not understand character buffer dtype format string ('%c')", **ts);
+ return number;
+}
+static void __Pyx_BufFmt_RaiseUnexpectedChar(char ch) {
+ PyErr_Format(PyExc_ValueError,
+ "Unexpected format string character: '%c'", ch);
+}
+static const char* __Pyx_BufFmt_DescribeTypeChar(char ch, int is_complex) {
+ switch (ch) {
+ case '?': return "'bool'";
+ case 'c': return "'char'";
+ case 'b': return "'signed char'";
+ case 'B': return "'unsigned char'";
+ case 'h': return "'short'";
+ case 'H': return "'unsigned short'";
+ case 'i': return "'int'";
+ case 'I': return "'unsigned int'";
+ case 'l': return "'long'";
+ case 'L': return "'unsigned long'";
+ case 'q': return "'long long'";
+ case 'Q': return "'unsigned long long'";
+ case 'f': return (is_complex ? "'complex float'" : "'float'");
+ case 'd': return (is_complex ? "'complex double'" : "'double'");
+ case 'g': return (is_complex ? "'complex long double'" : "'long double'");
+ case 'T': return "a struct";
+ case 'O': return "Python object";
+ case 'P': return "a pointer";
+ case 's': case 'p': return "a string";
+ case 0: return "end";
+ default: return "unparsable format string";
+ }
+}
+static size_t __Pyx_BufFmt_TypeCharToStandardSize(char ch, int is_complex) {
+ switch (ch) {
+ case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1;
+ case 'h': case 'H': return 2;
+ case 'i': case 'I': case 'l': case 'L': return 4;
+ case 'q': case 'Q': return 8;
+ case 'f': return (is_complex ? 8 : 4);
+ case 'd': return (is_complex ? 16 : 8);
+ case 'g': {
+ PyErr_SetString(PyExc_ValueError, "Python does not define a standard format string size for long double ('g')..");
+ return 0;
+ }
+ case 'O': case 'P': return sizeof(void*);
+ default:
+ __Pyx_BufFmt_RaiseUnexpectedChar(ch);
+ return 0;
+ }
+}
+static size_t __Pyx_BufFmt_TypeCharToNativeSize(char ch, int is_complex) {
+ switch (ch) {
+ case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1;
+ case 'h': case 'H': return sizeof(short);
+ case 'i': case 'I': return sizeof(int);
+ case 'l': case 'L': return sizeof(long);
+ #ifdef HAVE_LONG_LONG
+ case 'q': case 'Q': return sizeof(PY_LONG_LONG);
+ #endif
+ case 'f': return sizeof(float) * (is_complex ? 2 : 1);
+ case 'd': return sizeof(double) * (is_complex ? 2 : 1);
+ case 'g': return sizeof(long double) * (is_complex ? 2 : 1);
+ case 'O': case 'P': return sizeof(void*);
+ default: {
+ __Pyx_BufFmt_RaiseUnexpectedChar(ch);
+ return 0;
+ }
+ }
+}
+typedef struct { char c; short x; } __Pyx_st_short;
+typedef struct { char c; int x; } __Pyx_st_int;
+typedef struct { char c; long x; } __Pyx_st_long;
+typedef struct { char c; float x; } __Pyx_st_float;
+typedef struct { char c; double x; } __Pyx_st_double;
+typedef struct { char c; long double x; } __Pyx_st_longdouble;
+typedef struct { char c; void *x; } __Pyx_st_void_p;
+#ifdef HAVE_LONG_LONG
+typedef struct { char c; PY_LONG_LONG x; } __Pyx_st_longlong;
+#endif
+static size_t __Pyx_BufFmt_TypeCharToAlignment(char ch, int is_complex) {
+ CYTHON_UNUSED_VAR(is_complex);
+ switch (ch) {
+ case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1;
+ case 'h': case 'H': return sizeof(__Pyx_st_short) - sizeof(short);
+ case 'i': case 'I': return sizeof(__Pyx_st_int) - sizeof(int);
+ case 'l': case 'L': return sizeof(__Pyx_st_long) - sizeof(long);
+#ifdef HAVE_LONG_LONG
+ case 'q': case 'Q': return sizeof(__Pyx_st_longlong) - sizeof(PY_LONG_LONG);
+#endif
+ case 'f': return sizeof(__Pyx_st_float) - sizeof(float);
+ case 'd': return sizeof(__Pyx_st_double) - sizeof(double);
+ case 'g': return sizeof(__Pyx_st_longdouble) - sizeof(long double);
+ case 'P': case 'O': return sizeof(__Pyx_st_void_p) - sizeof(void*);
+ default:
+ __Pyx_BufFmt_RaiseUnexpectedChar(ch);
+ return 0;
+ }
+}
+/* These are for computing the padding at the end of the struct to align
+ on the first member of the struct. This will probably the same as above,
+ but we don't have any guarantees.
+ */
+typedef struct { short x; char c; } __Pyx_pad_short;
+typedef struct { int x; char c; } __Pyx_pad_int;
+typedef struct { long x; char c; } __Pyx_pad_long;
+typedef struct { float x; char c; } __Pyx_pad_float;
+typedef struct { double x; char c; } __Pyx_pad_double;
+typedef struct { long double x; char c; } __Pyx_pad_longdouble;
+typedef struct { void *x; char c; } __Pyx_pad_void_p;
+#ifdef HAVE_LONG_LONG
+typedef struct { PY_LONG_LONG x; char c; } __Pyx_pad_longlong;
+#endif
+static size_t __Pyx_BufFmt_TypeCharToPadding(char ch, int is_complex) {
+ CYTHON_UNUSED_VAR(is_complex);
+ switch (ch) {
+ case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1;
+ case 'h': case 'H': return sizeof(__Pyx_pad_short) - sizeof(short);
+ case 'i': case 'I': return sizeof(__Pyx_pad_int) - sizeof(int);
+ case 'l': case 'L': return sizeof(__Pyx_pad_long) - sizeof(long);
+#ifdef HAVE_LONG_LONG
+ case 'q': case 'Q': return sizeof(__Pyx_pad_longlong) - sizeof(PY_LONG_LONG);
+#endif
+ case 'f': return sizeof(__Pyx_pad_float) - sizeof(float);
+ case 'd': return sizeof(__Pyx_pad_double) - sizeof(double);
+ case 'g': return sizeof(__Pyx_pad_longdouble) - sizeof(long double);
+ case 'P': case 'O': return sizeof(__Pyx_pad_void_p) - sizeof(void*);
+ default:
+ __Pyx_BufFmt_RaiseUnexpectedChar(ch);
+ return 0;
+ }
+}
+static char __Pyx_BufFmt_TypeCharToGroup(char ch, int is_complex) {
+ switch (ch) {
+ case 'c':
+ return 'H';
+ case 'b': case 'h': case 'i':
+ case 'l': case 'q': case 's': case 'p':
+ return 'I';
+ case '?': case 'B': case 'H': case 'I': case 'L': case 'Q':
+ return 'U';
+ case 'f': case 'd': case 'g':
+ return (is_complex ? 'C' : 'R');
+ case 'O':
+ return 'O';
+ case 'P':
+ return 'P';
+ default: {
+ __Pyx_BufFmt_RaiseUnexpectedChar(ch);
+ return 0;
+ }
+ }
+}
+static void __Pyx_BufFmt_RaiseExpected(__Pyx_BufFmt_Context* ctx) {
+ if (ctx->head == NULL || ctx->head->field == &ctx->root) {
+ const char* expected;
+ const char* quote;
+ if (ctx->head == NULL) {
+ expected = "end";
+ quote = "";
+ } else {
+ expected = ctx->head->field->type->name;
+ quote = "'";
+ }
+ PyErr_Format(PyExc_ValueError,
+ "Buffer dtype mismatch, expected %s%s%s but got %s",
+ quote, expected, quote,
+ __Pyx_BufFmt_DescribeTypeChar(ctx->enc_type, ctx->is_complex));
+ } else {
+ __Pyx_StructField* field = ctx->head->field;
+ __Pyx_StructField* parent = (ctx->head - 1)->field;
+ PyErr_Format(PyExc_ValueError,
+ "Buffer dtype mismatch, expected '%s' but got %s in '%s.%s'",
+ field->type->name, __Pyx_BufFmt_DescribeTypeChar(ctx->enc_type, ctx->is_complex),
+ parent->type->name, field->name);
+ }
+}
+static int __Pyx_BufFmt_ProcessTypeChunk(__Pyx_BufFmt_Context* ctx) {
+ char group;
+ size_t size, offset, arraysize = 1;
+ if (ctx->enc_type == 0) return 0;
+ if (ctx->head->field->type->arraysize[0]) {
+ int i, ndim = 0;
+ if (ctx->enc_type == 's' || ctx->enc_type == 'p') {
+ ctx->is_valid_array = ctx->head->field->type->ndim == 1;
+ ndim = 1;
+ if (ctx->enc_count != ctx->head->field->type->arraysize[0]) {
+ PyErr_Format(PyExc_ValueError,
+ "Expected a dimension of size %zu, got %zu",
+ ctx->head->field->type->arraysize[0], ctx->enc_count);
+ return -1;
+ }
+ }
+ if (!ctx->is_valid_array) {
+ PyErr_Format(PyExc_ValueError, "Expected %d dimensions, got %d",
+ ctx->head->field->type->ndim, ndim);
+ return -1;
+ }
+ for (i = 0; i < ctx->head->field->type->ndim; i++) {
+ arraysize *= ctx->head->field->type->arraysize[i];
+ }
+ ctx->is_valid_array = 0;
+ ctx->enc_count = 1;
+ }
+ group = __Pyx_BufFmt_TypeCharToGroup(ctx->enc_type, ctx->is_complex);
+ do {
+ __Pyx_StructField* field = ctx->head->field;
+ __Pyx_TypeInfo* type = field->type;
+ if (ctx->enc_packmode == '@' || ctx->enc_packmode == '^') {
+ size = __Pyx_BufFmt_TypeCharToNativeSize(ctx->enc_type, ctx->is_complex);
+ } else {
+ size = __Pyx_BufFmt_TypeCharToStandardSize(ctx->enc_type, ctx->is_complex);
+ }
+ if (ctx->enc_packmode == '@') {
+ size_t align_at = __Pyx_BufFmt_TypeCharToAlignment(ctx->enc_type, ctx->is_complex);
+ size_t align_mod_offset;
+ if (align_at == 0) return -1;
+ align_mod_offset = ctx->fmt_offset % align_at;
+ if (align_mod_offset > 0) ctx->fmt_offset += align_at - align_mod_offset;
+ if (ctx->struct_alignment == 0)
+ ctx->struct_alignment = __Pyx_BufFmt_TypeCharToPadding(ctx->enc_type,
+ ctx->is_complex);
+ }
+ if (type->size != size || type->typegroup != group) {
+ if (type->typegroup == 'C' && type->fields != NULL) {
+ size_t parent_offset = ctx->head->parent_offset + field->offset;
+ ++ctx->head;
+ ctx->head->field = type->fields;
+ ctx->head->parent_offset = parent_offset;
+ continue;
+ }
+ if ((type->typegroup == 'H' || group == 'H') && type->size == size) {
+ } else {
+ __Pyx_BufFmt_RaiseExpected(ctx);
+ return -1;
+ }
+ }
+ offset = ctx->head->parent_offset + field->offset;
+ if (ctx->fmt_offset != offset) {
+ PyErr_Format(PyExc_ValueError,
+ "Buffer dtype mismatch; next field is at offset %" CYTHON_FORMAT_SSIZE_T "d but %" CYTHON_FORMAT_SSIZE_T "d expected",
+ (Py_ssize_t)ctx->fmt_offset, (Py_ssize_t)offset);
+ return -1;
+ }
+ ctx->fmt_offset += size;
+ if (arraysize)
+ ctx->fmt_offset += (arraysize - 1) * size;
+ --ctx->enc_count;
+ while (1) {
+ if (field == &ctx->root) {
+ ctx->head = NULL;
+ if (ctx->enc_count != 0) {
+ __Pyx_BufFmt_RaiseExpected(ctx);
+ return -1;
+ }
+ break;
+ }
+ ctx->head->field = ++field;
+ if (field->type == NULL) {
+ --ctx->head;
+ field = ctx->head->field;
+ continue;
+ } else if (field->type->typegroup == 'S') {
+ size_t parent_offset = ctx->head->parent_offset + field->offset;
+ if (field->type->fields->type == NULL) continue;
+ field = field->type->fields;
+ ++ctx->head;
+ ctx->head->field = field;
+ ctx->head->parent_offset = parent_offset;
+ break;
+ } else {
+ break;
+ }
+ }
+ } while (ctx->enc_count);
+ ctx->enc_type = 0;
+ ctx->is_complex = 0;
+ return 0;
+}
+static int
+__pyx_buffmt_parse_array(__Pyx_BufFmt_Context* ctx, const char** tsp)
+{
+ const char *ts = *tsp;
+ int i = 0, number, ndim;
+ ++ts;
+ if (ctx->new_count != 1) {
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot handle repeated arrays in format string");
+ return -1;
+ }
+ if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return -1;
+ ndim = ctx->head->field->type->ndim;
+ while (*ts && *ts != ')') {
+ switch (*ts) {
+ case ' ': case '\f': case '\r': case '\n': case '\t': case '\v': continue;
+ default: break;
+ }
+ number = __Pyx_BufFmt_ExpectNumber(&ts);
+ if (number == -1) return -1;
+ if (i < ndim && (size_t) number != ctx->head->field->type->arraysize[i]) {
+ PyErr_Format(PyExc_ValueError,
+ "Expected a dimension of size %zu, got %d",
+ ctx->head->field->type->arraysize[i], number);
+ return -1;
+ }
+ if (*ts != ',' && *ts != ')') {
+ PyErr_Format(PyExc_ValueError,
+ "Expected a comma in format string, got '%c'", *ts);
+ return -1;
+ }
+ if (*ts == ',') ts++;
+ i++;
+ }
+ if (i != ndim) {
+ PyErr_Format(PyExc_ValueError, "Expected %d dimension(s), got %d",
+ ctx->head->field->type->ndim, i);
+ return -1;
+ }
+ if (!*ts) {
+ PyErr_SetString(PyExc_ValueError,
+ "Unexpected end of format string, expected ')'");
+ return -1;
+ }
+ ctx->is_valid_array = 1;
+ ctx->new_count = 1;
+ *tsp = ++ts;
+ return 0;
+}
+static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts) {
+ int got_Z = 0;
+ while (1) {
+ switch(*ts) {
+ case 0:
+ if (ctx->enc_type != 0 && ctx->head == NULL) {
+ __Pyx_BufFmt_RaiseExpected(ctx);
+ return NULL;
+ }
+ if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL;
+ if (ctx->head != NULL) {
+ __Pyx_BufFmt_RaiseExpected(ctx);
+ return NULL;
+ }
+ return ts;
+ case ' ':
+ case '\r':
+ case '\n':
+ ++ts;
+ break;
+ case '<':
+ if (!__Pyx_Is_Little_Endian()) {
+ PyErr_SetString(PyExc_ValueError, "Little-endian buffer not supported on big-endian compiler");
+ return NULL;
+ }
+ ctx->new_packmode = '=';
+ ++ts;
+ break;
+ case '>':
+ case '!':
+ if (__Pyx_Is_Little_Endian()) {
+ PyErr_SetString(PyExc_ValueError, "Big-endian buffer not supported on little-endian compiler");
+ return NULL;
+ }
+ ctx->new_packmode = '=';
+ ++ts;
+ break;
+ case '=':
+ case '@':
+ case '^':
+ ctx->new_packmode = *ts++;
+ break;
+ case 'T':
+ {
+ const char* ts_after_sub;
+ size_t i, struct_count = ctx->new_count;
+ size_t struct_alignment = ctx->struct_alignment;
+ ctx->new_count = 1;
+ ++ts;
+ if (*ts != '{') {
+ PyErr_SetString(PyExc_ValueError, "Buffer acquisition: Expected '{' after 'T'");
+ return NULL;
+ }
+ if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL;
+ ctx->enc_type = 0;
+ ctx->enc_count = 0;
+ ctx->struct_alignment = 0;
+ ++ts;
+ ts_after_sub = ts;
+ for (i = 0; i != struct_count; ++i) {
+ ts_after_sub = __Pyx_BufFmt_CheckString(ctx, ts);
+ if (!ts_after_sub) return NULL;
+ }
+ ts = ts_after_sub;
+ if (struct_alignment) ctx->struct_alignment = struct_alignment;
+ }
+ break;
+ case '}':
+ {
+ size_t alignment = ctx->struct_alignment;
+ ++ts;
+ if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL;
+ ctx->enc_type = 0;
+ if (alignment && ctx->fmt_offset % alignment) {
+ ctx->fmt_offset += alignment - (ctx->fmt_offset % alignment);
+ }
+ }
+ return ts;
+ case 'x':
+ if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL;
+ ctx->fmt_offset += ctx->new_count;
+ ctx->new_count = 1;
+ ctx->enc_count = 0;
+ ctx->enc_type = 0;
+ ctx->enc_packmode = ctx->new_packmode;
+ ++ts;
+ break;
+ case 'Z':
+ got_Z = 1;
+ ++ts;
+ if (*ts != 'f' && *ts != 'd' && *ts != 'g') {
+ __Pyx_BufFmt_RaiseUnexpectedChar('Z');
+ return NULL;
+ }
+ CYTHON_FALLTHROUGH;
+ case '?': case 'c': case 'b': case 'B': case 'h': case 'H': case 'i': case 'I':
+ case 'l': case 'L': case 'q': case 'Q':
+ case 'f': case 'd': case 'g':
+ case 'O': case 'p':
+ if ((ctx->enc_type == *ts) && (got_Z == ctx->is_complex) &&
+ (ctx->enc_packmode == ctx->new_packmode) && (!ctx->is_valid_array)) {
+ ctx->enc_count += ctx->new_count;
+ ctx->new_count = 1;
+ got_Z = 0;
+ ++ts;
+ break;
+ }
+ CYTHON_FALLTHROUGH;
+ case 's':
+ if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL;
+ ctx->enc_count = ctx->new_count;
+ ctx->enc_packmode = ctx->new_packmode;
+ ctx->enc_type = *ts;
+ ctx->is_complex = got_Z;
+ ++ts;
+ ctx->new_count = 1;
+ got_Z = 0;
+ break;
+ case ':':
+ ++ts;
+ while(*ts != ':') ++ts;
+ ++ts;
+ break;
+ case '(':
+ if (__pyx_buffmt_parse_array(ctx, &ts) < 0) return NULL;
+ break;
+ default:
+ {
+ int number = __Pyx_BufFmt_ExpectNumber(&ts);
+ if (number == -1) return NULL;
+ ctx->new_count = (size_t)number;
+ }
+ }
+ }
+}
+
+/* BufferGetAndValidate */
+ static CYTHON_INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
+ if (unlikely(info->buf == NULL)) return;
+ if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
+ __Pyx_ReleaseBuffer(info);
+}
+static void __Pyx_ZeroBuffer(Py_buffer* buf) {
+ buf->buf = NULL;
+ buf->obj = NULL;
+ buf->strides = __Pyx_zeros;
+ buf->shape = __Pyx_zeros;
+ buf->suboffsets = __Pyx_minusones;
+}
+static int __Pyx__GetBufferAndValidate(
+ Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags,
+ int nd, int cast, __Pyx_BufFmt_StackElem* stack)
+{
+ buf->buf = NULL;
+ if (unlikely(__Pyx_GetBuffer(obj, buf, flags) == -1)) {
+ __Pyx_ZeroBuffer(buf);
+ return -1;
+ }
+ if (unlikely(buf->ndim != nd)) {
+ PyErr_Format(PyExc_ValueError,
+ "Buffer has wrong number of dimensions (expected %d, got %d)",
+ nd, buf->ndim);
+ goto fail;
+ }
+ if (!cast) {
+ __Pyx_BufFmt_Context ctx;
+ __Pyx_BufFmt_Init(&ctx, stack, dtype);
+ if (!__Pyx_BufFmt_CheckString(&ctx, buf->format)) goto fail;
+ }
+ if (unlikely((size_t)buf->itemsize != dtype->size)) {
+ PyErr_Format(PyExc_ValueError,
+ "Item size of buffer (%" CYTHON_FORMAT_SSIZE_T "d byte%s) does not match size of '%s' (%" CYTHON_FORMAT_SSIZE_T "d byte%s)",
+ buf->itemsize, (buf->itemsize > 1) ? "s" : "",
+ dtype->name, (Py_ssize_t)dtype->size, (dtype->size > 1) ? "s" : "");
+ goto fail;
+ }
+ if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
+ return 0;
+fail:;
+ __Pyx_SafeReleaseBuffer(buf);
+ return -1;
+}
+
+/* GetItemInt */
+ static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j) {
+ PyObject *r;
+ if (unlikely(!j)) return NULL;
+ r = PyObject_GetItem(o, j);
+ Py_DECREF(j);
+ return r;
+}
+static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i,
+ CYTHON_NCP_UNUSED int wraparound,
+ CYTHON_NCP_UNUSED int boundscheck) {
+#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
+ Py_ssize_t wrapped_i = i;
+ if (wraparound & unlikely(i < 0)) {
+ wrapped_i += PyList_GET_SIZE(o);
+ }
+ if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyList_GET_SIZE(o)))) {
+ PyObject *r = PyList_GET_ITEM(o, wrapped_i);
+ Py_INCREF(r);
+ return r;
+ }
+ return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i));
+#else
+ return PySequence_GetItem(o, i);
+#endif
+}
+static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i,
+ CYTHON_NCP_UNUSED int wraparound,
+ CYTHON_NCP_UNUSED int boundscheck) {
+#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
+ Py_ssize_t wrapped_i = i;
+ if (wraparound & unlikely(i < 0)) {
+ wrapped_i += PyTuple_GET_SIZE(o);
+ }
+ if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyTuple_GET_SIZE(o)))) {
+ PyObject *r = PyTuple_GET_ITEM(o, wrapped_i);
+ Py_INCREF(r);
+ return r;
+ }
+ return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i));
+#else
+ return PySequence_GetItem(o, i);
+#endif
+}
+static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, int is_list,
+ CYTHON_NCP_UNUSED int wraparound,
+ CYTHON_NCP_UNUSED int boundscheck) {
+#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS
+ if (is_list || PyList_CheckExact(o)) {
+ Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyList_GET_SIZE(o);
+ if ((!boundscheck) || (likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o))))) {
+ PyObject *r = PyList_GET_ITEM(o, n);
+ Py_INCREF(r);
+ return r;
+ }
+ }
+ else if (PyTuple_CheckExact(o)) {
+ Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyTuple_GET_SIZE(o);
+ if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyTuple_GET_SIZE(o)))) {
+ PyObject *r = PyTuple_GET_ITEM(o, n);
+ Py_INCREF(r);
+ return r;
+ }
+ } else {
+ PyMappingMethods *mm = Py_TYPE(o)->tp_as_mapping;
+ PySequenceMethods *sm = Py_TYPE(o)->tp_as_sequence;
+ if (mm && mm->mp_subscript) {
+ PyObject *r, *key = PyInt_FromSsize_t(i);
+ if (unlikely(!key)) return NULL;
+ r = mm->mp_subscript(o, key);
+ Py_DECREF(key);
+ return r;
+ }
+ if (likely(sm && sm->sq_item)) {
+ if (wraparound && unlikely(i < 0) && likely(sm->sq_length)) {
+ Py_ssize_t l = sm->sq_length(o);
+ if (likely(l >= 0)) {
+ i += l;
+ } else {
+ if (!PyErr_ExceptionMatches(PyExc_OverflowError))
+ return NULL;
+ PyErr_Clear();
+ }
+ }
+ return sm->sq_item(o, i);
+ }
+ }
+#else
+ if (is_list || !PyMapping_Check(o)) {
+ return PySequence_GetItem(o, i);
+ }
+#endif
+ return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i));
+}
+
+/* PyFunctionFastCall */
+ #if CYTHON_FAST_PYCALL && !CYTHON_VECTORCALL
+static PyObject* __Pyx_PyFunction_FastCallNoKw(PyCodeObject *co, PyObject **args, Py_ssize_t na,
+ PyObject *globals) {
+ PyFrameObject *f;
+ PyThreadState *tstate = __Pyx_PyThreadState_Current;
+ PyObject **fastlocals;
+ Py_ssize_t i;
+ PyObject *result;
+ assert(globals != NULL);
+ /* XXX Perhaps we should create a specialized
+ PyFrame_New() that doesn't take locals, but does
+ take builtins without sanity checking them.
+ */
+ assert(tstate != NULL);
+ f = PyFrame_New(tstate, co, globals, NULL);
+ if (f == NULL) {
+ return NULL;
+ }
+ fastlocals = __Pyx_PyFrame_GetLocalsplus(f);
+ for (i = 0; i < na; i++) {
+ Py_INCREF(*args);
+ fastlocals[i] = *args++;
+ }
+ result = PyEval_EvalFrameEx(f,0);
+ ++tstate->recursion_depth;
+ Py_DECREF(f);
+ --tstate->recursion_depth;
+ return result;
+}
+static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs) {
+ PyCodeObject *co = (PyCodeObject *)PyFunction_GET_CODE(func);
+ PyObject *globals = PyFunction_GET_GLOBALS(func);
+ PyObject *argdefs = PyFunction_GET_DEFAULTS(func);
+ PyObject *closure;
+#if PY_MAJOR_VERSION >= 3
+ PyObject *kwdefs;
+#endif
+ PyObject *kwtuple, **k;
+ PyObject **d;
+ Py_ssize_t nd;
+ Py_ssize_t nk;
+ PyObject *result;
+ assert(kwargs == NULL || PyDict_Check(kwargs));
+ nk = kwargs ? PyDict_Size(kwargs) : 0;
+ #if PY_MAJOR_VERSION < 3
+ if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) {
+ return NULL;
+ }
+ #else
+ if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) {
+ return NULL;
+ }
+ #endif
+ if (
+#if PY_MAJOR_VERSION >= 3
+ co->co_kwonlyargcount == 0 &&
+#endif
+ likely(kwargs == NULL || nk == 0) &&
+ co->co_flags == (CO_OPTIMIZED | CO_NEWLOCALS | CO_NOFREE)) {
+ if (argdefs == NULL && co->co_argcount == nargs) {
+ result = __Pyx_PyFunction_FastCallNoKw(co, args, nargs, globals);
+ goto done;
+ }
+ else if (nargs == 0 && argdefs != NULL
+ && co->co_argcount == Py_SIZE(argdefs)) {
+ /* function called with no arguments, but all parameters have
+ a default value: use default values as arguments .*/
+ args = &PyTuple_GET_ITEM(argdefs, 0);
+ result =__Pyx_PyFunction_FastCallNoKw(co, args, Py_SIZE(argdefs), globals);
+ goto done;
+ }
+ }
+ if (kwargs != NULL) {
+ Py_ssize_t pos, i;
+ kwtuple = PyTuple_New(2 * nk);
+ if (kwtuple == NULL) {
+ result = NULL;
+ goto done;
+ }
+ k = &PyTuple_GET_ITEM(kwtuple, 0);
+ pos = i = 0;
+ while (PyDict_Next(kwargs, &pos, &k[i], &k[i+1])) {
+ Py_INCREF(k[i]);
+ Py_INCREF(k[i+1]);
+ i += 2;
+ }
+ nk = i / 2;
+ }
+ else {
+ kwtuple = NULL;
+ k = NULL;
+ }
+ closure = PyFunction_GET_CLOSURE(func);
+#if PY_MAJOR_VERSION >= 3
+ kwdefs = PyFunction_GET_KW_DEFAULTS(func);
+#endif
+ if (argdefs != NULL) {
+ d = &PyTuple_GET_ITEM(argdefs, 0);
+ nd = Py_SIZE(argdefs);
+ }
+ else {
+ d = NULL;
+ nd = 0;
+ }
+#if PY_MAJOR_VERSION >= 3
+ result = PyEval_EvalCodeEx((PyObject*)co, globals, (PyObject *)NULL,
+ args, (int)nargs,
+ k, (int)nk,
+ d, (int)nd, kwdefs, closure);
+#else
+ result = PyEval_EvalCodeEx(co, globals, (PyObject *)NULL,
+ args, (int)nargs,
+ k, (int)nk,
+ d, (int)nd, closure);
+#endif
+ Py_XDECREF(kwtuple);
+done:
+ Py_LeaveRecursiveCall();
+ return result;
+}
+#endif
+
+/* PyObjectCallMethO */
+ #if CYTHON_COMPILING_IN_CPYTHON
+static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg) {
+ PyObject *self, *result;
+ PyCFunction cfunc;
+ cfunc = __Pyx_CyOrPyCFunction_GET_FUNCTION(func);
+ self = __Pyx_CyOrPyCFunction_GET_SELF(func);
+ #if PY_MAJOR_VERSION < 3
+ if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object")))
+ return NULL;
+ #else
+ if (unlikely(Py_EnterRecursiveCall(" while calling a Python object")))
+ return NULL;
+ #endif
+ result = cfunc(self, arg);
+ Py_LeaveRecursiveCall();
+ if (unlikely(!result) && unlikely(!PyErr_Occurred())) {
+ PyErr_SetString(
+ PyExc_SystemError,
+ "NULL result without error in PyObject_Call");
+ }
+ return result;
+}
+#endif
+
+/* PyObjectFastCall */
+ #if PY_VERSION_HEX < 0x03090000 || CYTHON_COMPILING_IN_LIMITED_API
+static PyObject* __Pyx_PyObject_FastCall_fallback(PyObject *func, PyObject **args, size_t nargs, PyObject *kwargs) {
+ PyObject *argstuple;
+ PyObject *result = 0;
+ size_t i;
+ argstuple = PyTuple_New((Py_ssize_t)nargs);
+ if (unlikely(!argstuple)) return NULL;
+ for (i = 0; i < nargs; i++) {
+ Py_INCREF(args[i]);
+ if (__Pyx_PyTuple_SET_ITEM(argstuple, (Py_ssize_t)i, args[i]) < 0) goto bad;
+ }
+ result = __Pyx_PyObject_Call(func, argstuple, kwargs);
+ bad:
+ Py_DECREF(argstuple);
+ return result;
+}
+#endif
+static CYTHON_INLINE PyObject* __Pyx_PyObject_FastCallDict(PyObject *func, PyObject **args, size_t _nargs, PyObject *kwargs) {
+ Py_ssize_t nargs = __Pyx_PyVectorcall_NARGS(_nargs);
+#if CYTHON_COMPILING_IN_CPYTHON
+ if (nargs == 0 && kwargs == NULL) {
+ if (__Pyx_CyOrPyCFunction_Check(func) && likely( __Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_NOARGS))
+ return __Pyx_PyObject_CallMethO(func, NULL);
+ }
+ else if (nargs == 1 && kwargs == NULL) {
+ if (__Pyx_CyOrPyCFunction_Check(func) && likely( __Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_O))
+ return __Pyx_PyObject_CallMethO(func, args[0]);
+ }
+#endif
+ #if PY_VERSION_HEX < 0x030800B1
+ #if CYTHON_FAST_PYCCALL
+ if (PyCFunction_Check(func)) {
+ if (kwargs) {
+ return _PyCFunction_FastCallDict(func, args, nargs, kwargs);
+ } else {
+ return _PyCFunction_FastCallKeywords(func, args, nargs, NULL);
+ }
+ }
+ #if PY_VERSION_HEX >= 0x030700A1
+ if (!kwargs && __Pyx_IS_TYPE(func, &PyMethodDescr_Type)) {
+ return _PyMethodDescr_FastCallKeywords(func, args, nargs, NULL);
+ }
+ #endif
+ #endif
+ #if CYTHON_FAST_PYCALL
+ if (PyFunction_Check(func)) {
+ return __Pyx_PyFunction_FastCallDict(func, args, nargs, kwargs);
+ }
+ #endif
+ #endif
+ if (kwargs == NULL) {
+ #if CYTHON_VECTORCALL
+ #if PY_VERSION_HEX < 0x03090000
+ vectorcallfunc f = _PyVectorcall_Function(func);
+ #else
+ vectorcallfunc f = PyVectorcall_Function(func);
+ #endif
+ if (f) {
+ return f(func, args, (size_t)nargs, NULL);
+ }
+ #elif defined(__Pyx_CyFunction_USED) && CYTHON_BACKPORT_VECTORCALL
+ if (__Pyx_CyFunction_CheckExact(func)) {
+ __pyx_vectorcallfunc f = __Pyx_CyFunction_func_vectorcall(func);
+ if (f) return f(func, args, (size_t)nargs, NULL);
+ }
+ #endif
+ }
+ if (nargs == 0) {
+ return __Pyx_PyObject_Call(func, __pyx_empty_tuple, kwargs);
+ }
+ #if PY_VERSION_HEX >= 0x03090000 && !CYTHON_COMPILING_IN_LIMITED_API
+ return PyObject_VectorcallDict(func, args, (size_t)nargs, kwargs);
+ #else
+ return __Pyx_PyObject_FastCall_fallback(func, args, (size_t)nargs, kwargs);
+ #endif
+}
+
+/* PyObjectCallOneArg */
+ static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg) {
+ PyObject *args[2] = {NULL, arg};
+ return __Pyx_PyObject_FastCall(func, args+1, 1 | __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET);
+}
+
+/* ObjectGetItem */
+ #if CYTHON_USE_TYPE_SLOTS
+static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject *index) {
+ PyObject *runerr = NULL;
+ Py_ssize_t key_value;
+ key_value = __Pyx_PyIndex_AsSsize_t(index);
+ if (likely(key_value != -1 || !(runerr = PyErr_Occurred()))) {
+ return __Pyx_GetItemInt_Fast(obj, key_value, 0, 1, 1);
+ }
+ if (PyErr_GivenExceptionMatches(runerr, PyExc_OverflowError)) {
+ __Pyx_TypeName index_type_name = __Pyx_PyType_GetName(Py_TYPE(index));
+ PyErr_Clear();
+ PyErr_Format(PyExc_IndexError,
+ "cannot fit '" __Pyx_FMT_TYPENAME "' into an index-sized integer", index_type_name);
+ __Pyx_DECREF_TypeName(index_type_name);
+ }
+ return NULL;
+}
+static PyObject *__Pyx_PyObject_GetItem_Slow(PyObject *obj, PyObject *key) {
+ __Pyx_TypeName obj_type_name;
+ if (likely(PyType_Check(obj))) {
+ PyObject *meth = __Pyx_PyObject_GetAttrStrNoError(obj, __pyx_n_s_class_getitem);
+ if (!meth) {
+ PyErr_Clear();
+ } else {
+ PyObject *result = __Pyx_PyObject_CallOneArg(meth, key);
+ Py_DECREF(meth);
+ return result;
+ }
+ }
+ obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj));
+ PyErr_Format(PyExc_TypeError,
+ "'" __Pyx_FMT_TYPENAME "' object is not subscriptable", obj_type_name);
+ __Pyx_DECREF_TypeName(obj_type_name);
+ return NULL;
+}
+static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key) {
+ PyTypeObject *tp = Py_TYPE(obj);
+ PyMappingMethods *mm = tp->tp_as_mapping;
+ PySequenceMethods *sm = tp->tp_as_sequence;
+ if (likely(mm && mm->mp_subscript)) {
+ return mm->mp_subscript(obj, key);
+ }
+ if (likely(sm && sm->sq_item)) {
+ return __Pyx_PyObject_GetIndex(obj, key);
+ }
+ return __Pyx_PyObject_GetItem_Slow(obj, key);
+}
+#endif
+
+/* ExtTypeTest */
+ static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) {
+ __Pyx_TypeName obj_type_name;
+ __Pyx_TypeName type_name;
+ if (unlikely(!type)) {
+ PyErr_SetString(PyExc_SystemError, "Missing type object");
+ return 0;
+ }
+ if (likely(__Pyx_TypeCheck(obj, type)))
+ return 1;
+ obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj));
+ type_name = __Pyx_PyType_GetName(type);
+ PyErr_Format(PyExc_TypeError,
+ "Cannot convert " __Pyx_FMT_TYPENAME " to " __Pyx_FMT_TYPENAME,
+ obj_type_name, type_name);
+ __Pyx_DECREF_TypeName(obj_type_name);
+ __Pyx_DECREF_TypeName(type_name);
+ return 0;
+}
+
+/* PyIntBinop */
+ #if !CYTHON_COMPILING_IN_PYPY
+static PyObject* __Pyx_PyInt_AddObjC(PyObject *op1, PyObject *op2, long intval, int inplace, int zerodivision_check) {
+ CYTHON_MAYBE_UNUSED_VAR(intval);
+ CYTHON_MAYBE_UNUSED_VAR(inplace);
+ CYTHON_UNUSED_VAR(zerodivision_check);
+ #if PY_MAJOR_VERSION < 3
+ if (likely(PyInt_CheckExact(op1))) {
+ const long b = intval;
+ long x;
+ long a = PyInt_AS_LONG(op1);
+
+ x = (long)((unsigned long)a + (unsigned long)b);
+ if (likely((x^a) >= 0 || (x^b) >= 0))
+ return PyInt_FromLong(x);
+ return PyLong_Type.tp_as_number->nb_add(op1, op2);
+ }
+ #endif
+ #if CYTHON_USE_PYLONG_INTERNALS
+ if (likely(PyLong_CheckExact(op1))) {
+ const long b = intval;
+ long a, x;
+#ifdef HAVE_LONG_LONG
+ const PY_LONG_LONG llb = intval;
+ PY_LONG_LONG lla, llx;
+#endif
+ if (unlikely(__Pyx_PyLong_IsZero(op1))) {
+ return __Pyx_NewRef(op2);
+ }
+ if (likely(__Pyx_PyLong_IsCompact(op1))) {
+ a = __Pyx_PyLong_CompactValue(op1);
+ } else {
+ const digit* digits = __Pyx_PyLong_Digits(op1);
+ const Py_ssize_t size = __Pyx_PyLong_SignedDigitCount(op1);
+ switch (size) {
+ case -2:
+ if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) {
+ a = -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]));
+ break;
+ #ifdef HAVE_LONG_LONG
+ } else if (8 * sizeof(PY_LONG_LONG) - 1 > 2 * PyLong_SHIFT) {
+ lla = -(PY_LONG_LONG) (((((unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0]));
+ goto long_long;
+ #endif
+ }
+ CYTHON_FALLTHROUGH;
+ case 2:
+ if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) {
+ a = (long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]));
+ break;
+ #ifdef HAVE_LONG_LONG
+ } else if (8 * sizeof(PY_LONG_LONG) - 1 > 2 * PyLong_SHIFT) {
+ lla = (PY_LONG_LONG) (((((unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0]));
+ goto long_long;
+ #endif
+ }
+ CYTHON_FALLTHROUGH;
+ case -3:
+ if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) {
+ a = -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]));
+ break;
+ #ifdef HAVE_LONG_LONG
+ } else if (8 * sizeof(PY_LONG_LONG) - 1 > 3 * PyLong_SHIFT) {
+ lla = -(PY_LONG_LONG) (((((((unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0]));
+ goto long_long;
+ #endif
+ }
+ CYTHON_FALLTHROUGH;
+ case 3:
+ if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) {
+ a = (long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]));
+ break;
+ #ifdef HAVE_LONG_LONG
+ } else if (8 * sizeof(PY_LONG_LONG) - 1 > 3 * PyLong_SHIFT) {
+ lla = (PY_LONG_LONG) (((((((unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0]));
+ goto long_long;
+ #endif
+ }
+ CYTHON_FALLTHROUGH;
+ case -4:
+ if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) {
+ a = -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]));
+ break;
+ #ifdef HAVE_LONG_LONG
+ } else if (8 * sizeof(PY_LONG_LONG) - 1 > 4 * PyLong_SHIFT) {
+ lla = -(PY_LONG_LONG) (((((((((unsigned PY_LONG_LONG)digits[3]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0]));
+ goto long_long;
+ #endif
+ }
+ CYTHON_FALLTHROUGH;
+ case 4:
+ if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) {
+ a = (long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]));
+ break;
+ #ifdef HAVE_LONG_LONG
+ } else if (8 * sizeof(PY_LONG_LONG) - 1 > 4 * PyLong_SHIFT) {
+ lla = (PY_LONG_LONG) (((((((((unsigned PY_LONG_LONG)digits[3]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0]));
+ goto long_long;
+ #endif
+ }
+ CYTHON_FALLTHROUGH;
+ default: return PyLong_Type.tp_as_number->nb_add(op1, op2);
+ }
+ }
+ x = a + b;
+ return PyLong_FromLong(x);
+#ifdef HAVE_LONG_LONG
+ long_long:
+ llx = lla + llb;
+ return PyLong_FromLongLong(llx);
+#endif
+
+
+ }
+ #endif
+ if (PyFloat_CheckExact(op1)) {
+ const long b = intval;
+#if CYTHON_COMPILING_IN_LIMITED_API
+ double a = __pyx_PyFloat_AsDouble(op1);
+#else
+ double a = PyFloat_AS_DOUBLE(op1);
+#endif
+ double result;
+
+ PyFPE_START_PROTECT("add", return NULL)
+ result = ((double)a) + (double)b;
+ PyFPE_END_PROTECT(result)
+ return PyFloat_FromDouble(result);
+ }
+ return (inplace ? PyNumber_InPlaceAdd : PyNumber_Add)(op1, op2);
+}
+#endif
+
+/* PyDictVersioning */
+ #if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS
+static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj) {
+ PyObject *dict = Py_TYPE(obj)->tp_dict;
+ return likely(dict) ? __PYX_GET_DICT_VERSION(dict) : 0;
+}
+static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj) {
+ PyObject **dictptr = NULL;
+ Py_ssize_t offset = Py_TYPE(obj)->tp_dictoffset;
+ if (offset) {
+#if CYTHON_COMPILING_IN_CPYTHON
+ dictptr = (likely(offset > 0)) ? (PyObject **) ((char *)obj + offset) : _PyObject_GetDictPtr(obj);
+#else
+ dictptr = _PyObject_GetDictPtr(obj);
+#endif
+ }
+ return (dictptr && *dictptr) ? __PYX_GET_DICT_VERSION(*dictptr) : 0;
+}
+static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version) {
+ PyObject *dict = Py_TYPE(obj)->tp_dict;
+ if (unlikely(!dict) || unlikely(tp_dict_version != __PYX_GET_DICT_VERSION(dict)))
+ return 0;
+ return obj_dict_version == __Pyx_get_object_dict_version(obj);
+}
+#endif
+
+/* GetModuleGlobalName */
+ #if CYTHON_USE_DICT_VERSIONS
+static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value)
+#else
+static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name)
+#endif
+{
+ PyObject *result;
+#if !CYTHON_AVOID_BORROWED_REFS
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 && PY_VERSION_HEX < 0x030d0000
+ result = _PyDict_GetItem_KnownHash(__pyx_d, name, ((PyASCIIObject *) name)->hash);
+ __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version)
+ if (likely(result)) {
+ return __Pyx_NewRef(result);
+ } else if (unlikely(PyErr_Occurred())) {
+ return NULL;
+ }
+#elif CYTHON_COMPILING_IN_LIMITED_API
+ if (unlikely(!__pyx_m)) {
+ return NULL;
+ }
+ result = PyObject_GetAttr(__pyx_m, name);
+ if (likely(result)) {
+ return result;
+ }
+#else
+ result = PyDict_GetItem(__pyx_d, name);
+ __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version)
+ if (likely(result)) {
+ return __Pyx_NewRef(result);
+ }
+#endif
+#else
+ result = PyObject_GetItem(__pyx_d, name);
+ __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version)
+ if (likely(result)) {
+ return __Pyx_NewRef(result);
+ }
+ PyErr_Clear();
+#endif
+ return __Pyx_GetBuiltinName(name);
+}
+
+/* BufferIndexError */
+ static void __Pyx_RaiseBufferIndexError(int axis) {
+ PyErr_Format(PyExc_IndexError,
+ "Out of bounds on buffer access (axis %d)", axis);
+}
+
+/* TypeImport */
+ #ifndef __PYX_HAVE_RT_ImportType_3_0_12
+#define __PYX_HAVE_RT_ImportType_3_0_12
+static PyTypeObject *__Pyx_ImportType_3_0_12(PyObject *module, const char *module_name, const char *class_name,
+ size_t size, size_t alignment, enum __Pyx_ImportType_CheckSize_3_0_12 check_size)
+{
+ PyObject *result = 0;
+ char warning[200];
+ Py_ssize_t basicsize;
+ Py_ssize_t itemsize;
+#if CYTHON_COMPILING_IN_LIMITED_API
+ PyObject *py_basicsize;
+ PyObject *py_itemsize;
+#endif
+ result = PyObject_GetAttrString(module, class_name);
+ if (!result)
+ goto bad;
+ if (!PyType_Check(result)) {
+ PyErr_Format(PyExc_TypeError,
+ "%.200s.%.200s is not a type object",
+ module_name, class_name);
+ goto bad;
+ }
+#if !CYTHON_COMPILING_IN_LIMITED_API
+ basicsize = ((PyTypeObject *)result)->tp_basicsize;
+ itemsize = ((PyTypeObject *)result)->tp_itemsize;
+#else
+ py_basicsize = PyObject_GetAttrString(result, "__basicsize__");
+ if (!py_basicsize)
+ goto bad;
+ basicsize = PyLong_AsSsize_t(py_basicsize);
+ Py_DECREF(py_basicsize);
+ py_basicsize = 0;
+ if (basicsize == (Py_ssize_t)-1 && PyErr_Occurred())
+ goto bad;
+ py_itemsize = PyObject_GetAttrString(result, "__itemsize__");
+ if (!py_itemsize)
+ goto bad;
+ itemsize = PyLong_AsSsize_t(py_itemsize);
+ Py_DECREF(py_itemsize);
+ py_itemsize = 0;
+ if (itemsize == (Py_ssize_t)-1 && PyErr_Occurred())
+ goto bad;
+#endif
+ if (itemsize) {
+ if (size % alignment) {
+ alignment = size % alignment;
+ }
+ if (itemsize < (Py_ssize_t)alignment)
+ itemsize = (Py_ssize_t)alignment;
+ }
+ if ((size_t)(basicsize + itemsize) < size) {
+ PyErr_Format(PyExc_ValueError,
+ "%.200s.%.200s size changed, may indicate binary incompatibility. "
+ "Expected %zd from C header, got %zd from PyObject",
+ module_name, class_name, size, basicsize+itemsize);
+ goto bad;
+ }
+ if (check_size == __Pyx_ImportType_CheckSize_Error_3_0_12 &&
+ ((size_t)basicsize > size || (size_t)(basicsize + itemsize) < size)) {
+ PyErr_Format(PyExc_ValueError,
+ "%.200s.%.200s size changed, may indicate binary incompatibility. "
+ "Expected %zd from C header, got %zd-%zd from PyObject",
+ module_name, class_name, size, basicsize, basicsize+itemsize);
+ goto bad;
+ }
+ else if (check_size == __Pyx_ImportType_CheckSize_Warn_3_0_12 && (size_t)basicsize > size) {
+ PyOS_snprintf(warning, sizeof(warning),
+ "%s.%s size changed, may indicate binary incompatibility. "
+ "Expected %zd from C header, got %zd from PyObject",
+ module_name, class_name, size, basicsize);
+ if (PyErr_WarnEx(NULL, warning, 0) < 0) goto bad;
+ }
+ return (PyTypeObject *)result;
+bad:
+ Py_XDECREF(result);
+ return NULL;
+}
+#endif
+
+/* Import */
+ static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level) {
+ PyObject *module = 0;
+ PyObject *empty_dict = 0;
+ PyObject *empty_list = 0;
+ #if PY_MAJOR_VERSION < 3
+ PyObject *py_import;
+ py_import = __Pyx_PyObject_GetAttrStr(__pyx_b, __pyx_n_s_import);
+ if (unlikely(!py_import))
+ goto bad;
+ if (!from_list) {
+ empty_list = PyList_New(0);
+ if (unlikely(!empty_list))
+ goto bad;
+ from_list = empty_list;
+ }
+ #endif
+ empty_dict = PyDict_New();
+ if (unlikely(!empty_dict))
+ goto bad;
+ {
+ #if PY_MAJOR_VERSION >= 3
+ if (level == -1) {
+ if (strchr(__Pyx_MODULE_NAME, '.') != NULL) {
+ module = PyImport_ImportModuleLevelObject(
+ name, __pyx_d, empty_dict, from_list, 1);
+ if (unlikely(!module)) {
+ if (unlikely(!PyErr_ExceptionMatches(PyExc_ImportError)))
+ goto bad;
+ PyErr_Clear();
+ }
+ }
+ level = 0;
+ }
+ #endif
+ if (!module) {
+ #if PY_MAJOR_VERSION < 3
+ PyObject *py_level = PyInt_FromLong(level);
+ if (unlikely(!py_level))
+ goto bad;
+ module = PyObject_CallFunctionObjArgs(py_import,
+ name, __pyx_d, empty_dict, from_list, py_level, (PyObject *)NULL);
+ Py_DECREF(py_level);
+ #else
+ module = PyImport_ImportModuleLevelObject(
+ name, __pyx_d, empty_dict, from_list, level);
+ #endif
+ }
+ }
+bad:
+ Py_XDECREF(empty_dict);
+ Py_XDECREF(empty_list);
+ #if PY_MAJOR_VERSION < 3
+ Py_XDECREF(py_import);
+ #endif
+ return module;
+}
+
+/* ImportDottedModule */
+ #if PY_MAJOR_VERSION >= 3
+static PyObject *__Pyx__ImportDottedModule_Error(PyObject *name, PyObject *parts_tuple, Py_ssize_t count) {
+ PyObject *partial_name = NULL, *slice = NULL, *sep = NULL;
+ if (unlikely(PyErr_Occurred())) {
+ PyErr_Clear();
+ }
+ if (likely(PyTuple_GET_SIZE(parts_tuple) == count)) {
+ partial_name = name;
+ } else {
+ slice = PySequence_GetSlice(parts_tuple, 0, count);
+ if (unlikely(!slice))
+ goto bad;
+ sep = PyUnicode_FromStringAndSize(".", 1);
+ if (unlikely(!sep))
+ goto bad;
+ partial_name = PyUnicode_Join(sep, slice);
+ }
+ PyErr_Format(
+#if PY_MAJOR_VERSION < 3
+ PyExc_ImportError,
+ "No module named '%s'", PyString_AS_STRING(partial_name));
+#else
+#if PY_VERSION_HEX >= 0x030600B1
+ PyExc_ModuleNotFoundError,
+#else
+ PyExc_ImportError,
+#endif
+ "No module named '%U'", partial_name);
+#endif
+bad:
+ Py_XDECREF(sep);
+ Py_XDECREF(slice);
+ Py_XDECREF(partial_name);
+ return NULL;
+}
+#endif
+#if PY_MAJOR_VERSION >= 3
+static PyObject *__Pyx__ImportDottedModule_Lookup(PyObject *name) {
+ PyObject *imported_module;
+#if PY_VERSION_HEX < 0x030700A1 || (CYTHON_COMPILING_IN_PYPY && PYPY_VERSION_NUM < 0x07030400)
+ PyObject *modules = PyImport_GetModuleDict();
+ if (unlikely(!modules))
+ return NULL;
+ imported_module = __Pyx_PyDict_GetItemStr(modules, name);
+ Py_XINCREF(imported_module);
+#else
+ imported_module = PyImport_GetModule(name);
+#endif
+ return imported_module;
+}
+#endif
+#if PY_MAJOR_VERSION >= 3
+static PyObject *__Pyx_ImportDottedModule_WalkParts(PyObject *module, PyObject *name, PyObject *parts_tuple) {
+ Py_ssize_t i, nparts;
+ nparts = PyTuple_GET_SIZE(parts_tuple);
+ for (i=1; i < nparts && module; i++) {
+ PyObject *part, *submodule;
+#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
+ part = PyTuple_GET_ITEM(parts_tuple, i);
+#else
+ part = PySequence_ITEM(parts_tuple, i);
+#endif
+ submodule = __Pyx_PyObject_GetAttrStrNoError(module, part);
+#if !(CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS)
+ Py_DECREF(part);
+#endif
+ Py_DECREF(module);
+ module = submodule;
+ }
+ if (unlikely(!module)) {
+ return __Pyx__ImportDottedModule_Error(name, parts_tuple, i);
+ }
+ return module;
+}
+#endif
+static PyObject *__Pyx__ImportDottedModule(PyObject *name, PyObject *parts_tuple) {
+#if PY_MAJOR_VERSION < 3
+ PyObject *module, *from_list, *star = __pyx_n_s__10;
+ CYTHON_UNUSED_VAR(parts_tuple);
+ from_list = PyList_New(1);
+ if (unlikely(!from_list))
+ return NULL;
+ Py_INCREF(star);
+ PyList_SET_ITEM(from_list, 0, star);
+ module = __Pyx_Import(name, from_list, 0);
+ Py_DECREF(from_list);
+ return module;
+#else
+ PyObject *imported_module;
+ PyObject *module = __Pyx_Import(name, NULL, 0);
+ if (!parts_tuple || unlikely(!module))
+ return module;
+ imported_module = __Pyx__ImportDottedModule_Lookup(name);
+ if (likely(imported_module)) {
+ Py_DECREF(module);
+ return imported_module;
+ }
+ PyErr_Clear();
+ return __Pyx_ImportDottedModule_WalkParts(module, name, parts_tuple);
+#endif
+}
+static PyObject *__Pyx_ImportDottedModule(PyObject *name, PyObject *parts_tuple) {
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030400B1
+ PyObject *module = __Pyx__ImportDottedModule_Lookup(name);
+ if (likely(module)) {
+ PyObject *spec = __Pyx_PyObject_GetAttrStrNoError(module, __pyx_n_s_spec);
+ if (likely(spec)) {
+ PyObject *unsafe = __Pyx_PyObject_GetAttrStrNoError(spec, __pyx_n_s_initializing);
+ if (likely(!unsafe || !__Pyx_PyObject_IsTrue(unsafe))) {
+ Py_DECREF(spec);
+ spec = NULL;
+ }
+ Py_XDECREF(unsafe);
+ }
+ if (likely(!spec)) {
+ PyErr_Clear();
+ return module;
+ }
+ Py_DECREF(spec);
+ Py_DECREF(module);
+ } else if (PyErr_Occurred()) {
+ PyErr_Clear();
+ }
+#endif
+ return __Pyx__ImportDottedModule(name, parts_tuple);
+}
+
+/* FixUpExtensionType */
+ #if CYTHON_USE_TYPE_SPECS
+static int __Pyx_fix_up_extension_type_from_spec(PyType_Spec *spec, PyTypeObject *type) {
+#if PY_VERSION_HEX > 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API
+ CYTHON_UNUSED_VAR(spec);
+ CYTHON_UNUSED_VAR(type);
+#else
+ const PyType_Slot *slot = spec->slots;
+ while (slot && slot->slot && slot->slot != Py_tp_members)
+ slot++;
+ if (slot && slot->slot == Py_tp_members) {
+ int changed = 0;
+#if !(PY_VERSION_HEX <= 0x030900b1 && CYTHON_COMPILING_IN_CPYTHON)
+ const
+#endif
+ PyMemberDef *memb = (PyMemberDef*) slot->pfunc;
+ while (memb && memb->name) {
+ if (memb->name[0] == '_' && memb->name[1] == '_') {
+#if PY_VERSION_HEX < 0x030900b1
+ if (strcmp(memb->name, "__weaklistoffset__") == 0) {
+ assert(memb->type == T_PYSSIZET);
+ assert(memb->flags == READONLY);
+ type->tp_weaklistoffset = memb->offset;
+ changed = 1;
+ }
+ else if (strcmp(memb->name, "__dictoffset__") == 0) {
+ assert(memb->type == T_PYSSIZET);
+ assert(memb->flags == READONLY);
+ type->tp_dictoffset = memb->offset;
+ changed = 1;
+ }
+#if CYTHON_METH_FASTCALL
+ else if (strcmp(memb->name, "__vectorcalloffset__") == 0) {
+ assert(memb->type == T_PYSSIZET);
+ assert(memb->flags == READONLY);
+#if PY_VERSION_HEX >= 0x030800b4
+ type->tp_vectorcall_offset = memb->offset;
+#else
+ type->tp_print = (printfunc) memb->offset;
+#endif
+ changed = 1;
+ }
+#endif
+#else
+ if ((0));
+#endif
+#if PY_VERSION_HEX <= 0x030900b1 && CYTHON_COMPILING_IN_CPYTHON
+ else if (strcmp(memb->name, "__module__") == 0) {
+ PyObject *descr;
+ assert(memb->type == T_OBJECT);
+ assert(memb->flags == 0 || memb->flags == READONLY);
+ descr = PyDescr_NewMember(type, memb);
+ if (unlikely(!descr))
+ return -1;
+ if (unlikely(PyDict_SetItem(type->tp_dict, PyDescr_NAME(descr), descr) < 0)) {
+ Py_DECREF(descr);
+ return -1;
+ }
+ Py_DECREF(descr);
+ changed = 1;
+ }
+#endif
+ }
+ memb++;
+ }
+ if (changed)
+ PyType_Modified(type);
+ }
+#endif
+ return 0;
+}
+#endif
+
+/* FetchSharedCythonModule */
+ static PyObject *__Pyx_FetchSharedCythonABIModule(void) {
+ return __Pyx_PyImport_AddModuleRef((char*) __PYX_ABI_MODULE_NAME);
+}
+
+/* FetchCommonType */
+ static int __Pyx_VerifyCachedType(PyObject *cached_type,
+ const char *name,
+ Py_ssize_t basicsize,
+ Py_ssize_t expected_basicsize) {
+ if (!PyType_Check(cached_type)) {
+ PyErr_Format(PyExc_TypeError,
+ "Shared Cython type %.200s is not a type object", name);
+ return -1;
+ }
+ if (basicsize != expected_basicsize) {
+ PyErr_Format(PyExc_TypeError,
+ "Shared Cython type %.200s has the wrong size, try recompiling",
+ name);
+ return -1;
+ }
+ return 0;
+}
+#if !CYTHON_USE_TYPE_SPECS
+static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type) {
+ PyObject* abi_module;
+ const char* object_name;
+ PyTypeObject *cached_type = NULL;
+ abi_module = __Pyx_FetchSharedCythonABIModule();
+ if (!abi_module) return NULL;
+ object_name = strrchr(type->tp_name, '.');
+ object_name = object_name ? object_name+1 : type->tp_name;
+ cached_type = (PyTypeObject*) PyObject_GetAttrString(abi_module, object_name);
+ if (cached_type) {
+ if (__Pyx_VerifyCachedType(
+ (PyObject *)cached_type,
+ object_name,
+ cached_type->tp_basicsize,
+ type->tp_basicsize) < 0) {
+ goto bad;
+ }
+ goto done;
+ }
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) goto bad;
+ PyErr_Clear();
+ if (PyType_Ready(type) < 0) goto bad;
+ if (PyObject_SetAttrString(abi_module, object_name, (PyObject *)type) < 0)
+ goto bad;
+ Py_INCREF(type);
+ cached_type = type;
+done:
+ Py_DECREF(abi_module);
+ return cached_type;
+bad:
+ Py_XDECREF(cached_type);
+ cached_type = NULL;
+ goto done;
+}
+#else
+static PyTypeObject *__Pyx_FetchCommonTypeFromSpec(PyObject *module, PyType_Spec *spec, PyObject *bases) {
+ PyObject *abi_module, *cached_type = NULL;
+ const char* object_name = strrchr(spec->name, '.');
+ object_name = object_name ? object_name+1 : spec->name;
+ abi_module = __Pyx_FetchSharedCythonABIModule();
+ if (!abi_module) return NULL;
+ cached_type = PyObject_GetAttrString(abi_module, object_name);
+ if (cached_type) {
+ Py_ssize_t basicsize;
+#if CYTHON_COMPILING_IN_LIMITED_API
+ PyObject *py_basicsize;
+ py_basicsize = PyObject_GetAttrString(cached_type, "__basicsize__");
+ if (unlikely(!py_basicsize)) goto bad;
+ basicsize = PyLong_AsSsize_t(py_basicsize);
+ Py_DECREF(py_basicsize);
+ py_basicsize = 0;
+ if (unlikely(basicsize == (Py_ssize_t)-1) && PyErr_Occurred()) goto bad;
+#else
+ basicsize = likely(PyType_Check(cached_type)) ? ((PyTypeObject*) cached_type)->tp_basicsize : -1;
+#endif
+ if (__Pyx_VerifyCachedType(
+ cached_type,
+ object_name,
+ basicsize,
+ spec->basicsize) < 0) {
+ goto bad;
+ }
+ goto done;
+ }
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) goto bad;
+ PyErr_Clear();
+ CYTHON_UNUSED_VAR(module);
+ cached_type = __Pyx_PyType_FromModuleAndSpec(abi_module, spec, bases);
+ if (unlikely(!cached_type)) goto bad;
+ if (unlikely(__Pyx_fix_up_extension_type_from_spec(spec, (PyTypeObject *) cached_type) < 0)) goto bad;
+ if (PyObject_SetAttrString(abi_module, object_name, cached_type) < 0) goto bad;
+done:
+ Py_DECREF(abi_module);
+ assert(cached_type == NULL || PyType_Check(cached_type));
+ return (PyTypeObject *) cached_type;
+bad:
+ Py_XDECREF(cached_type);
+ cached_type = NULL;
+ goto done;
+}
+#endif
+
+/* PyVectorcallFastCallDict */
+ #if CYTHON_METH_FASTCALL
+static PyObject *__Pyx_PyVectorcall_FastCallDict_kw(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw)
+{
+ PyObject *res = NULL;
+ PyObject *kwnames;
+ PyObject **newargs;
+ PyObject **kwvalues;
+ Py_ssize_t i, pos;
+ size_t j;
+ PyObject *key, *value;
+ unsigned long keys_are_strings;
+ Py_ssize_t nkw = PyDict_GET_SIZE(kw);
+ newargs = (PyObject **)PyMem_Malloc((nargs + (size_t)nkw) * sizeof(args[0]));
+ if (unlikely(newargs == NULL)) {
+ PyErr_NoMemory();
+ return NULL;
+ }
+ for (j = 0; j < nargs; j++) newargs[j] = args[j];
+ kwnames = PyTuple_New(nkw);
+ if (unlikely(kwnames == NULL)) {
+ PyMem_Free(newargs);
+ return NULL;
+ }
+ kwvalues = newargs + nargs;
+ pos = i = 0;
+ keys_are_strings = Py_TPFLAGS_UNICODE_SUBCLASS;
+ while (PyDict_Next(kw, &pos, &key, &value)) {
+ keys_are_strings &= Py_TYPE(key)->tp_flags;
+ Py_INCREF(key);
+ Py_INCREF(value);
+ PyTuple_SET_ITEM(kwnames, i, key);
+ kwvalues[i] = value;
+ i++;
+ }
+ if (unlikely(!keys_are_strings)) {
+ PyErr_SetString(PyExc_TypeError, "keywords must be strings");
+ goto cleanup;
+ }
+ res = vc(func, newargs, nargs, kwnames);
+cleanup:
+ Py_DECREF(kwnames);
+ for (i = 0; i < nkw; i++)
+ Py_DECREF(kwvalues[i]);
+ PyMem_Free(newargs);
+ return res;
+}
+static CYTHON_INLINE PyObject *__Pyx_PyVectorcall_FastCallDict(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw)
+{
+ if (likely(kw == NULL) || PyDict_GET_SIZE(kw) == 0) {
+ return vc(func, args, nargs, NULL);
+ }
+ return __Pyx_PyVectorcall_FastCallDict_kw(func, vc, args, nargs, kw);
+}
+#endif
+
+/* CythonFunctionShared */
+ #if CYTHON_COMPILING_IN_LIMITED_API
+static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc) {
+ if (__Pyx_CyFunction_Check(func)) {
+ return PyCFunction_GetFunction(((__pyx_CyFunctionObject*)func)->func) == (PyCFunction) cfunc;
+ } else if (PyCFunction_Check(func)) {
+ return PyCFunction_GetFunction(func) == (PyCFunction) cfunc;
+ }
+ return 0;
+}
+#else
+static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc) {
+ return __Pyx_CyOrPyCFunction_Check(func) && __Pyx_CyOrPyCFunction_GET_FUNCTION(func) == (PyCFunction) cfunc;
+}
+#endif
+static CYTHON_INLINE void __Pyx__CyFunction_SetClassObj(__pyx_CyFunctionObject* f, PyObject* classobj) {
+#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API
+ __Pyx_Py_XDECREF_SET(
+ __Pyx_CyFunction_GetClassObj(f),
+ ((classobj) ? __Pyx_NewRef(classobj) : NULL));
+#else
+ __Pyx_Py_XDECREF_SET(
+ ((PyCMethodObject *) (f))->mm_class,
+ (PyTypeObject*)((classobj) ? __Pyx_NewRef(classobj) : NULL));
+#endif
+}
+static PyObject *
+__Pyx_CyFunction_get_doc(__pyx_CyFunctionObject *op, void *closure)
+{
+ CYTHON_UNUSED_VAR(closure);
+ if (unlikely(op->func_doc == NULL)) {
+#if CYTHON_COMPILING_IN_LIMITED_API
+ op->func_doc = PyObject_GetAttrString(op->func, "__doc__");
+ if (unlikely(!op->func_doc)) return NULL;
+#else
+ if (((PyCFunctionObject*)op)->m_ml->ml_doc) {
+#if PY_MAJOR_VERSION >= 3
+ op->func_doc = PyUnicode_FromString(((PyCFunctionObject*)op)->m_ml->ml_doc);
+#else
+ op->func_doc = PyString_FromString(((PyCFunctionObject*)op)->m_ml->ml_doc);
+#endif
+ if (unlikely(op->func_doc == NULL))
+ return NULL;
+ } else {
+ Py_INCREF(Py_None);
+ return Py_None;
+ }
+#endif
+ }
+ Py_INCREF(op->func_doc);
+ return op->func_doc;
+}
+static int
+__Pyx_CyFunction_set_doc(__pyx_CyFunctionObject *op, PyObject *value, void *context)
+{
+ CYTHON_UNUSED_VAR(context);
+ if (value == NULL) {
+ value = Py_None;
+ }
+ Py_INCREF(value);
+ __Pyx_Py_XDECREF_SET(op->func_doc, value);
+ return 0;
+}
+static PyObject *
+__Pyx_CyFunction_get_name(__pyx_CyFunctionObject *op, void *context)
+{
+ CYTHON_UNUSED_VAR(context);
+ if (unlikely(op->func_name == NULL)) {
+#if CYTHON_COMPILING_IN_LIMITED_API
+ op->func_name = PyObject_GetAttrString(op->func, "__name__");
+#elif PY_MAJOR_VERSION >= 3
+ op->func_name = PyUnicode_InternFromString(((PyCFunctionObject*)op)->m_ml->ml_name);
+#else
+ op->func_name = PyString_InternFromString(((PyCFunctionObject*)op)->m_ml->ml_name);
+#endif
+ if (unlikely(op->func_name == NULL))
+ return NULL;
+ }
+ Py_INCREF(op->func_name);
+ return op->func_name;
+}
+static int
+__Pyx_CyFunction_set_name(__pyx_CyFunctionObject *op, PyObject *value, void *context)
+{
+ CYTHON_UNUSED_VAR(context);
+#if PY_MAJOR_VERSION >= 3
+ if (unlikely(value == NULL || !PyUnicode_Check(value)))
+#else
+ if (unlikely(value == NULL || !PyString_Check(value)))
+#endif
+ {
+ PyErr_SetString(PyExc_TypeError,
+ "__name__ must be set to a string object");
+ return -1;
+ }
+ Py_INCREF(value);
+ __Pyx_Py_XDECREF_SET(op->func_name, value);
+ return 0;
+}
+static PyObject *
+__Pyx_CyFunction_get_qualname(__pyx_CyFunctionObject *op, void *context)
+{
+ CYTHON_UNUSED_VAR(context);
+ Py_INCREF(op->func_qualname);
+ return op->func_qualname;
+}
+static int
+__Pyx_CyFunction_set_qualname(__pyx_CyFunctionObject *op, PyObject *value, void *context)
+{
+ CYTHON_UNUSED_VAR(context);
+#if PY_MAJOR_VERSION >= 3
+ if (unlikely(value == NULL || !PyUnicode_Check(value)))
+#else
+ if (unlikely(value == NULL || !PyString_Check(value)))
+#endif
+ {
+ PyErr_SetString(PyExc_TypeError,
+ "__qualname__ must be set to a string object");
+ return -1;
+ }
+ Py_INCREF(value);
+ __Pyx_Py_XDECREF_SET(op->func_qualname, value);
+ return 0;
+}
+static PyObject *
+__Pyx_CyFunction_get_dict(__pyx_CyFunctionObject *op, void *context)
+{
+ CYTHON_UNUSED_VAR(context);
+ if (unlikely(op->func_dict == NULL)) {
+ op->func_dict = PyDict_New();
+ if (unlikely(op->func_dict == NULL))
+ return NULL;
+ }
+ Py_INCREF(op->func_dict);
+ return op->func_dict;
+}
+static int
+__Pyx_CyFunction_set_dict(__pyx_CyFunctionObject *op, PyObject *value, void *context)
+{
+ CYTHON_UNUSED_VAR(context);
+ if (unlikely(value == NULL)) {
+ PyErr_SetString(PyExc_TypeError,
+ "function's dictionary may not be deleted");
+ return -1;
+ }
+ if (unlikely(!PyDict_Check(value))) {
+ PyErr_SetString(PyExc_TypeError,
+ "setting function's dictionary to a non-dict");
+ return -1;
+ }
+ Py_INCREF(value);
+ __Pyx_Py_XDECREF_SET(op->func_dict, value);
+ return 0;
+}
+static PyObject *
+__Pyx_CyFunction_get_globals(__pyx_CyFunctionObject *op, void *context)
+{
+ CYTHON_UNUSED_VAR(context);
+ Py_INCREF(op->func_globals);
+ return op->func_globals;
+}
+static PyObject *
+__Pyx_CyFunction_get_closure(__pyx_CyFunctionObject *op, void *context)
+{
+ CYTHON_UNUSED_VAR(op);
+ CYTHON_UNUSED_VAR(context);
+ Py_INCREF(Py_None);
+ return Py_None;
+}
+static PyObject *
+__Pyx_CyFunction_get_code(__pyx_CyFunctionObject *op, void *context)
+{
+ PyObject* result = (op->func_code) ? op->func_code : Py_None;
+ CYTHON_UNUSED_VAR(context);
+ Py_INCREF(result);
+ return result;
+}
+static int
+__Pyx_CyFunction_init_defaults(__pyx_CyFunctionObject *op) {
+ int result = 0;
+ PyObject *res = op->defaults_getter((PyObject *) op);
+ if (unlikely(!res))
+ return -1;
+ #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
+ op->defaults_tuple = PyTuple_GET_ITEM(res, 0);
+ Py_INCREF(op->defaults_tuple);
+ op->defaults_kwdict = PyTuple_GET_ITEM(res, 1);
+ Py_INCREF(op->defaults_kwdict);
+ #else
+ op->defaults_tuple = __Pyx_PySequence_ITEM(res, 0);
+ if (unlikely(!op->defaults_tuple)) result = -1;
+ else {
+ op->defaults_kwdict = __Pyx_PySequence_ITEM(res, 1);
+ if (unlikely(!op->defaults_kwdict)) result = -1;
+ }
+ #endif
+ Py_DECREF(res);
+ return result;
+}
+static int
+__Pyx_CyFunction_set_defaults(__pyx_CyFunctionObject *op, PyObject* value, void *context) {
+ CYTHON_UNUSED_VAR(context);
+ if (!value) {
+ value = Py_None;
+ } else if (unlikely(value != Py_None && !PyTuple_Check(value))) {
+ PyErr_SetString(PyExc_TypeError,
+ "__defaults__ must be set to a tuple object");
+ return -1;
+ }
+ PyErr_WarnEx(PyExc_RuntimeWarning, "changes to cyfunction.__defaults__ will not "
+ "currently affect the values used in function calls", 1);
+ Py_INCREF(value);
+ __Pyx_Py_XDECREF_SET(op->defaults_tuple, value);
+ return 0;
+}
+static PyObject *
+__Pyx_CyFunction_get_defaults(__pyx_CyFunctionObject *op, void *context) {
+ PyObject* result = op->defaults_tuple;
+ CYTHON_UNUSED_VAR(context);
+ if (unlikely(!result)) {
+ if (op->defaults_getter) {
+ if (unlikely(__Pyx_CyFunction_init_defaults(op) < 0)) return NULL;
+ result = op->defaults_tuple;
+ } else {
+ result = Py_None;
+ }
+ }
+ Py_INCREF(result);
+ return result;
+}
+static int
+__Pyx_CyFunction_set_kwdefaults(__pyx_CyFunctionObject *op, PyObject* value, void *context) {
+ CYTHON_UNUSED_VAR(context);
+ if (!value) {
+ value = Py_None;
+ } else if (unlikely(value != Py_None && !PyDict_Check(value))) {
+ PyErr_SetString(PyExc_TypeError,
+ "__kwdefaults__ must be set to a dict object");
+ return -1;
+ }
+ PyErr_WarnEx(PyExc_RuntimeWarning, "changes to cyfunction.__kwdefaults__ will not "
+ "currently affect the values used in function calls", 1);
+ Py_INCREF(value);
+ __Pyx_Py_XDECREF_SET(op->defaults_kwdict, value);
+ return 0;
+}
+static PyObject *
+__Pyx_CyFunction_get_kwdefaults(__pyx_CyFunctionObject *op, void *context) {
+ PyObject* result = op->defaults_kwdict;
+ CYTHON_UNUSED_VAR(context);
+ if (unlikely(!result)) {
+ if (op->defaults_getter) {
+ if (unlikely(__Pyx_CyFunction_init_defaults(op) < 0)) return NULL;
+ result = op->defaults_kwdict;
+ } else {
+ result = Py_None;
+ }
+ }
+ Py_INCREF(result);
+ return result;
+}
+static int
+__Pyx_CyFunction_set_annotations(__pyx_CyFunctionObject *op, PyObject* value, void *context) {
+ CYTHON_UNUSED_VAR(context);
+ if (!value || value == Py_None) {
+ value = NULL;
+ } else if (unlikely(!PyDict_Check(value))) {
+ PyErr_SetString(PyExc_TypeError,
+ "__annotations__ must be set to a dict object");
+ return -1;
+ }
+ Py_XINCREF(value);
+ __Pyx_Py_XDECREF_SET(op->func_annotations, value);
+ return 0;
+}
+static PyObject *
+__Pyx_CyFunction_get_annotations(__pyx_CyFunctionObject *op, void *context) {
+ PyObject* result = op->func_annotations;
+ CYTHON_UNUSED_VAR(context);
+ if (unlikely(!result)) {
+ result = PyDict_New();
+ if (unlikely(!result)) return NULL;
+ op->func_annotations = result;
+ }
+ Py_INCREF(result);
+ return result;
+}
+static PyObject *
+__Pyx_CyFunction_get_is_coroutine(__pyx_CyFunctionObject *op, void *context) {
+ int is_coroutine;
+ CYTHON_UNUSED_VAR(context);
+ if (op->func_is_coroutine) {
+ return __Pyx_NewRef(op->func_is_coroutine);
+ }
+ is_coroutine = op->flags & __Pyx_CYFUNCTION_COROUTINE;
+#if PY_VERSION_HEX >= 0x03050000
+ if (is_coroutine) {
+ PyObject *module, *fromlist, *marker = __pyx_n_s_is_coroutine;
+ fromlist = PyList_New(1);
+ if (unlikely(!fromlist)) return NULL;
+ Py_INCREF(marker);
+#if CYTHON_ASSUME_SAFE_MACROS
+ PyList_SET_ITEM(fromlist, 0, marker);
+#else
+ if (unlikely(PyList_SetItem(fromlist, 0, marker) < 0)) {
+ Py_DECREF(marker);
+ Py_DECREF(fromlist);
+ return NULL;
+ }
+#endif
+ module = PyImport_ImportModuleLevelObject(__pyx_n_s_asyncio_coroutines, NULL, NULL, fromlist, 0);
+ Py_DECREF(fromlist);
+ if (unlikely(!module)) goto ignore;
+ op->func_is_coroutine = __Pyx_PyObject_GetAttrStr(module, marker);
+ Py_DECREF(module);
+ if (likely(op->func_is_coroutine)) {
+ return __Pyx_NewRef(op->func_is_coroutine);
+ }
+ignore:
+ PyErr_Clear();
+ }
+#endif
+ op->func_is_coroutine = __Pyx_PyBool_FromLong(is_coroutine);
+ return __Pyx_NewRef(op->func_is_coroutine);
+}
+#if CYTHON_COMPILING_IN_LIMITED_API
+static PyObject *
+__Pyx_CyFunction_get_module(__pyx_CyFunctionObject *op, void *context) {
+ CYTHON_UNUSED_VAR(context);
+ return PyObject_GetAttrString(op->func, "__module__");
+}
+static int
+__Pyx_CyFunction_set_module(__pyx_CyFunctionObject *op, PyObject* value, void *context) {
+ CYTHON_UNUSED_VAR(context);
+ return PyObject_SetAttrString(op->func, "__module__", value);
+}
+#endif
+static PyGetSetDef __pyx_CyFunction_getsets[] = {
+ {(char *) "func_doc", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0},
+ {(char *) "__doc__", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0},
+ {(char *) "func_name", (getter)__Pyx_CyFunction_get_name, (setter)__Pyx_CyFunction_set_name, 0, 0},
+ {(char *) "__name__", (getter)__Pyx_CyFunction_get_name, (setter)__Pyx_CyFunction_set_name, 0, 0},
+ {(char *) "__qualname__", (getter)__Pyx_CyFunction_get_qualname, (setter)__Pyx_CyFunction_set_qualname, 0, 0},
+ {(char *) "func_dict", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0},
+ {(char *) "__dict__", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0},
+ {(char *) "func_globals", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0},
+ {(char *) "__globals__", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0},
+ {(char *) "func_closure", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0},
+ {(char *) "__closure__", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0},
+ {(char *) "func_code", (getter)__Pyx_CyFunction_get_code, 0, 0, 0},
+ {(char *) "__code__", (getter)__Pyx_CyFunction_get_code, 0, 0, 0},
+ {(char *) "func_defaults", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0},
+ {(char *) "__defaults__", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0},
+ {(char *) "__kwdefaults__", (getter)__Pyx_CyFunction_get_kwdefaults, (setter)__Pyx_CyFunction_set_kwdefaults, 0, 0},
+ {(char *) "__annotations__", (getter)__Pyx_CyFunction_get_annotations, (setter)__Pyx_CyFunction_set_annotations, 0, 0},
+ {(char *) "_is_coroutine", (getter)__Pyx_CyFunction_get_is_coroutine, 0, 0, 0},
+#if CYTHON_COMPILING_IN_LIMITED_API
+ {"__module__", (getter)__Pyx_CyFunction_get_module, (setter)__Pyx_CyFunction_set_module, 0, 0},
+#endif
+ {0, 0, 0, 0, 0}
+};
+static PyMemberDef __pyx_CyFunction_members[] = {
+#if !CYTHON_COMPILING_IN_LIMITED_API
+ {(char *) "__module__", T_OBJECT, offsetof(PyCFunctionObject, m_module), 0, 0},
+#endif
+#if CYTHON_USE_TYPE_SPECS
+ {(char *) "__dictoffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_dict), READONLY, 0},
+#if CYTHON_METH_FASTCALL
+#if CYTHON_BACKPORT_VECTORCALL
+ {(char *) "__vectorcalloffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_vectorcall), READONLY, 0},
+#else
+#if !CYTHON_COMPILING_IN_LIMITED_API
+ {(char *) "__vectorcalloffset__", T_PYSSIZET, offsetof(PyCFunctionObject, vectorcall), READONLY, 0},
+#endif
+#endif
+#endif
+#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API
+ {(char *) "__weaklistoffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_weakreflist), READONLY, 0},
+#else
+ {(char *) "__weaklistoffset__", T_PYSSIZET, offsetof(PyCFunctionObject, m_weakreflist), READONLY, 0},
+#endif
+#endif
+ {0, 0, 0, 0, 0}
+};
+static PyObject *
+__Pyx_CyFunction_reduce(__pyx_CyFunctionObject *m, PyObject *args)
+{
+ CYTHON_UNUSED_VAR(args);
+#if PY_MAJOR_VERSION >= 3
+ Py_INCREF(m->func_qualname);
+ return m->func_qualname;
+#else
+ return PyString_FromString(((PyCFunctionObject*)m)->m_ml->ml_name);
+#endif
+}
+static PyMethodDef __pyx_CyFunction_methods[] = {
+ {"__reduce__", (PyCFunction)__Pyx_CyFunction_reduce, METH_VARARGS, 0},
+ {0, 0, 0, 0}
+};
+#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API
+#define __Pyx_CyFunction_weakreflist(cyfunc) ((cyfunc)->func_weakreflist)
+#else
+#define __Pyx_CyFunction_weakreflist(cyfunc) (((PyCFunctionObject*)cyfunc)->m_weakreflist)
+#endif
+static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject *op, PyMethodDef *ml, int flags, PyObject* qualname,
+ PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) {
+#if !CYTHON_COMPILING_IN_LIMITED_API
+ PyCFunctionObject *cf = (PyCFunctionObject*) op;
+#endif
+ if (unlikely(op == NULL))
+ return NULL;
+#if CYTHON_COMPILING_IN_LIMITED_API
+ op->func = PyCFunction_NewEx(ml, (PyObject*)op, module);
+ if (unlikely(!op->func)) return NULL;
+#endif
+ op->flags = flags;
+ __Pyx_CyFunction_weakreflist(op) = NULL;
+#if !CYTHON_COMPILING_IN_LIMITED_API
+ cf->m_ml = ml;
+ cf->m_self = (PyObject *) op;
+#endif
+ Py_XINCREF(closure);
+ op->func_closure = closure;
+#if !CYTHON_COMPILING_IN_LIMITED_API
+ Py_XINCREF(module);
+ cf->m_module = module;
+#endif
+ op->func_dict = NULL;
+ op->func_name = NULL;
+ Py_INCREF(qualname);
+ op->func_qualname = qualname;
+ op->func_doc = NULL;
+#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API
+ op->func_classobj = NULL;
+#else
+ ((PyCMethodObject*)op)->mm_class = NULL;
+#endif
+ op->func_globals = globals;
+ Py_INCREF(op->func_globals);
+ Py_XINCREF(code);
+ op->func_code = code;
+ op->defaults_pyobjects = 0;
+ op->defaults_size = 0;
+ op->defaults = NULL;
+ op->defaults_tuple = NULL;
+ op->defaults_kwdict = NULL;
+ op->defaults_getter = NULL;
+ op->func_annotations = NULL;
+ op->func_is_coroutine = NULL;
+#if CYTHON_METH_FASTCALL
+ switch (ml->ml_flags & (METH_VARARGS | METH_FASTCALL | METH_NOARGS | METH_O | METH_KEYWORDS | METH_METHOD)) {
+ case METH_NOARGS:
+ __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_NOARGS;
+ break;
+ case METH_O:
+ __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_O;
+ break;
+ case METH_METHOD | METH_FASTCALL | METH_KEYWORDS:
+ __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD;
+ break;
+ case METH_FASTCALL | METH_KEYWORDS:
+ __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS;
+ break;
+ case METH_VARARGS | METH_KEYWORDS:
+ __Pyx_CyFunction_func_vectorcall(op) = NULL;
+ break;
+ default:
+ PyErr_SetString(PyExc_SystemError, "Bad call flags for CyFunction");
+ Py_DECREF(op);
+ return NULL;
+ }
+#endif
+ return (PyObject *) op;
+}
+static int
+__Pyx_CyFunction_clear(__pyx_CyFunctionObject *m)
+{
+ Py_CLEAR(m->func_closure);
+#if CYTHON_COMPILING_IN_LIMITED_API
+ Py_CLEAR(m->func);
+#else
+ Py_CLEAR(((PyCFunctionObject*)m)->m_module);
+#endif
+ Py_CLEAR(m->func_dict);
+ Py_CLEAR(m->func_name);
+ Py_CLEAR(m->func_qualname);
+ Py_CLEAR(m->func_doc);
+ Py_CLEAR(m->func_globals);
+ Py_CLEAR(m->func_code);
+#if !CYTHON_COMPILING_IN_LIMITED_API
+#if PY_VERSION_HEX < 0x030900B1
+ Py_CLEAR(__Pyx_CyFunction_GetClassObj(m));
+#else
+ {
+ PyObject *cls = (PyObject*) ((PyCMethodObject *) (m))->mm_class;
+ ((PyCMethodObject *) (m))->mm_class = NULL;
+ Py_XDECREF(cls);
+ }
+#endif
+#endif
+ Py_CLEAR(m->defaults_tuple);
+ Py_CLEAR(m->defaults_kwdict);
+ Py_CLEAR(m->func_annotations);
+ Py_CLEAR(m->func_is_coroutine);
+ if (m->defaults) {
+ PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m);
+ int i;
+ for (i = 0; i < m->defaults_pyobjects; i++)
+ Py_XDECREF(pydefaults[i]);
+ PyObject_Free(m->defaults);
+ m->defaults = NULL;
+ }
+ return 0;
+}
+static void __Pyx__CyFunction_dealloc(__pyx_CyFunctionObject *m)
+{
+ if (__Pyx_CyFunction_weakreflist(m) != NULL)
+ PyObject_ClearWeakRefs((PyObject *) m);
+ __Pyx_CyFunction_clear(m);
+ __Pyx_PyHeapTypeObject_GC_Del(m);
+}
+static void __Pyx_CyFunction_dealloc(__pyx_CyFunctionObject *m)
+{
+ PyObject_GC_UnTrack(m);
+ __Pyx__CyFunction_dealloc(m);
+}
+static int __Pyx_CyFunction_traverse(__pyx_CyFunctionObject *m, visitproc visit, void *arg)
+{
+ Py_VISIT(m->func_closure);
+#if CYTHON_COMPILING_IN_LIMITED_API
+ Py_VISIT(m->func);
+#else
+ Py_VISIT(((PyCFunctionObject*)m)->m_module);
+#endif
+ Py_VISIT(m->func_dict);
+ Py_VISIT(m->func_name);
+ Py_VISIT(m->func_qualname);
+ Py_VISIT(m->func_doc);
+ Py_VISIT(m->func_globals);
+ Py_VISIT(m->func_code);
+#if !CYTHON_COMPILING_IN_LIMITED_API
+ Py_VISIT(__Pyx_CyFunction_GetClassObj(m));
+#endif
+ Py_VISIT(m->defaults_tuple);
+ Py_VISIT(m->defaults_kwdict);
+ Py_VISIT(m->func_is_coroutine);
+ if (m->defaults) {
+ PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m);
+ int i;
+ for (i = 0; i < m->defaults_pyobjects; i++)
+ Py_VISIT(pydefaults[i]);
+ }
+ return 0;
+}
+static PyObject*
+__Pyx_CyFunction_repr(__pyx_CyFunctionObject *op)
+{
+#if PY_MAJOR_VERSION >= 3
+ return PyUnicode_FromFormat("",
+ op->func_qualname, (void *)op);
+#else
+ return PyString_FromFormat("",
+ PyString_AsString(op->func_qualname), (void *)op);
+#endif
+}
+static PyObject * __Pyx_CyFunction_CallMethod(PyObject *func, PyObject *self, PyObject *arg, PyObject *kw) {
+#if CYTHON_COMPILING_IN_LIMITED_API
+ PyObject *f = ((__pyx_CyFunctionObject*)func)->func;
+ PyObject *py_name = NULL;
+ PyCFunction meth;
+ int flags;
+ meth = PyCFunction_GetFunction(f);
+ if (unlikely(!meth)) return NULL;
+ flags = PyCFunction_GetFlags(f);
+ if (unlikely(flags < 0)) return NULL;
+#else
+ PyCFunctionObject* f = (PyCFunctionObject*)func;
+ PyCFunction meth = f->m_ml->ml_meth;
+ int flags = f->m_ml->ml_flags;
+#endif
+ Py_ssize_t size;
+ switch (flags & (METH_VARARGS | METH_KEYWORDS | METH_NOARGS | METH_O)) {
+ case METH_VARARGS:
+ if (likely(kw == NULL || PyDict_Size(kw) == 0))
+ return (*meth)(self, arg);
+ break;
+ case METH_VARARGS | METH_KEYWORDS:
+ return (*(PyCFunctionWithKeywords)(void*)meth)(self, arg, kw);
+ case METH_NOARGS:
+ if (likely(kw == NULL || PyDict_Size(kw) == 0)) {
+#if CYTHON_ASSUME_SAFE_MACROS
+ size = PyTuple_GET_SIZE(arg);
+#else
+ size = PyTuple_Size(arg);
+ if (unlikely(size < 0)) return NULL;
+#endif
+ if (likely(size == 0))
+ return (*meth)(self, NULL);
+#if CYTHON_COMPILING_IN_LIMITED_API
+ py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL);
+ if (!py_name) return NULL;
+ PyErr_Format(PyExc_TypeError,
+ "%.200S() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)",
+ py_name, size);
+ Py_DECREF(py_name);
+#else
+ PyErr_Format(PyExc_TypeError,
+ "%.200s() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)",
+ f->m_ml->ml_name, size);
+#endif
+ return NULL;
+ }
+ break;
+ case METH_O:
+ if (likely(kw == NULL || PyDict_Size(kw) == 0)) {
+#if CYTHON_ASSUME_SAFE_MACROS
+ size = PyTuple_GET_SIZE(arg);
+#else
+ size = PyTuple_Size(arg);
+ if (unlikely(size < 0)) return NULL;
+#endif
+ if (likely(size == 1)) {
+ PyObject *result, *arg0;
+ #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
+ arg0 = PyTuple_GET_ITEM(arg, 0);
+ #else
+ arg0 = __Pyx_PySequence_ITEM(arg, 0); if (unlikely(!arg0)) return NULL;
+ #endif
+ result = (*meth)(self, arg0);
+ #if !(CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS)
+ Py_DECREF(arg0);
+ #endif
+ return result;
+ }
+#if CYTHON_COMPILING_IN_LIMITED_API
+ py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL);
+ if (!py_name) return NULL;
+ PyErr_Format(PyExc_TypeError,
+ "%.200S() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)",
+ py_name, size);
+ Py_DECREF(py_name);
+#else
+ PyErr_Format(PyExc_TypeError,
+ "%.200s() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)",
+ f->m_ml->ml_name, size);
+#endif
+ return NULL;
+ }
+ break;
+ default:
+ PyErr_SetString(PyExc_SystemError, "Bad call flags for CyFunction");
+ return NULL;
+ }
+#if CYTHON_COMPILING_IN_LIMITED_API
+ py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL);
+ if (!py_name) return NULL;
+ PyErr_Format(PyExc_TypeError, "%.200S() takes no keyword arguments",
+ py_name);
+ Py_DECREF(py_name);
+#else
+ PyErr_Format(PyExc_TypeError, "%.200s() takes no keyword arguments",
+ f->m_ml->ml_name);
+#endif
+ return NULL;
+}
+static CYTHON_INLINE PyObject *__Pyx_CyFunction_Call(PyObject *func, PyObject *arg, PyObject *kw) {
+ PyObject *self, *result;
+#if CYTHON_COMPILING_IN_LIMITED_API
+ self = PyCFunction_GetSelf(((__pyx_CyFunctionObject*)func)->func);
+ if (unlikely(!self) && PyErr_Occurred()) return NULL;
+#else
+ self = ((PyCFunctionObject*)func)->m_self;
+#endif
+ result = __Pyx_CyFunction_CallMethod(func, self, arg, kw);
+ return result;
+}
+static PyObject *__Pyx_CyFunction_CallAsMethod(PyObject *func, PyObject *args, PyObject *kw) {
+ PyObject *result;
+ __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *) func;
+#if CYTHON_METH_FASTCALL
+ __pyx_vectorcallfunc vc = __Pyx_CyFunction_func_vectorcall(cyfunc);
+ if (vc) {
+#if CYTHON_ASSUME_SAFE_MACROS
+ return __Pyx_PyVectorcall_FastCallDict(func, vc, &PyTuple_GET_ITEM(args, 0), (size_t)PyTuple_GET_SIZE(args), kw);
+#else
+ (void) &__Pyx_PyVectorcall_FastCallDict;
+ return PyVectorcall_Call(func, args, kw);
+#endif
+ }
+#endif
+ if ((cyfunc->flags & __Pyx_CYFUNCTION_CCLASS) && !(cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD)) {
+ Py_ssize_t argc;
+ PyObject *new_args;
+ PyObject *self;
+#if CYTHON_ASSUME_SAFE_MACROS
+ argc = PyTuple_GET_SIZE(args);
+#else
+ argc = PyTuple_Size(args);
+ if (unlikely(!argc) < 0) return NULL;
+#endif
+ new_args = PyTuple_GetSlice(args, 1, argc);
+ if (unlikely(!new_args))
+ return NULL;
+ self = PyTuple_GetItem(args, 0);
+ if (unlikely(!self)) {
+ Py_DECREF(new_args);
+#if PY_MAJOR_VERSION > 2
+ PyErr_Format(PyExc_TypeError,
+ "unbound method %.200S() needs an argument",
+ cyfunc->func_qualname);
+#else
+ PyErr_SetString(PyExc_TypeError,
+ "unbound method needs an argument");
+#endif
+ return NULL;
+ }
+ result = __Pyx_CyFunction_CallMethod(func, self, new_args, kw);
+ Py_DECREF(new_args);
+ } else {
+ result = __Pyx_CyFunction_Call(func, args, kw);
+ }
+ return result;
+}
+#if CYTHON_METH_FASTCALL
+static CYTHON_INLINE int __Pyx_CyFunction_Vectorcall_CheckArgs(__pyx_CyFunctionObject *cyfunc, Py_ssize_t nargs, PyObject *kwnames)
+{
+ int ret = 0;
+ if ((cyfunc->flags & __Pyx_CYFUNCTION_CCLASS) && !(cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD)) {
+ if (unlikely(nargs < 1)) {
+ PyErr_Format(PyExc_TypeError, "%.200s() needs an argument",
+ ((PyCFunctionObject*)cyfunc)->m_ml->ml_name);
+ return -1;
+ }
+ ret = 1;
+ }
+ if (unlikely(kwnames) && unlikely(PyTuple_GET_SIZE(kwnames))) {
+ PyErr_Format(PyExc_TypeError,
+ "%.200s() takes no keyword arguments", ((PyCFunctionObject*)cyfunc)->m_ml->ml_name);
+ return -1;
+ }
+ return ret;
+}
+static PyObject * __Pyx_CyFunction_Vectorcall_NOARGS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames)
+{
+ __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func;
+ PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml;
+#if CYTHON_BACKPORT_VECTORCALL
+ Py_ssize_t nargs = (Py_ssize_t)nargsf;
+#else
+ Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
+#endif
+ PyObject *self;
+ switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, kwnames)) {
+ case 1:
+ self = args[0];
+ args += 1;
+ nargs -= 1;
+ break;
+ case 0:
+ self = ((PyCFunctionObject*)cyfunc)->m_self;
+ break;
+ default:
+ return NULL;
+ }
+ if (unlikely(nargs != 0)) {
+ PyErr_Format(PyExc_TypeError,
+ "%.200s() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)",
+ def->ml_name, nargs);
+ return NULL;
+ }
+ return def->ml_meth(self, NULL);
+}
+static PyObject * __Pyx_CyFunction_Vectorcall_O(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames)
+{
+ __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func;
+ PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml;
+#if CYTHON_BACKPORT_VECTORCALL
+ Py_ssize_t nargs = (Py_ssize_t)nargsf;
+#else
+ Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
+#endif
+ PyObject *self;
+ switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, kwnames)) {
+ case 1:
+ self = args[0];
+ args += 1;
+ nargs -= 1;
+ break;
+ case 0:
+ self = ((PyCFunctionObject*)cyfunc)->m_self;
+ break;
+ default:
+ return NULL;
+ }
+ if (unlikely(nargs != 1)) {
+ PyErr_Format(PyExc_TypeError,
+ "%.200s() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)",
+ def->ml_name, nargs);
+ return NULL;
+ }
+ return def->ml_meth(self, args[0]);
+}
+static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames)
+{
+ __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func;
+ PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml;
+#if CYTHON_BACKPORT_VECTORCALL
+ Py_ssize_t nargs = (Py_ssize_t)nargsf;
+#else
+ Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
+#endif
+ PyObject *self;
+ switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, NULL)) {
+ case 1:
+ self = args[0];
+ args += 1;
+ nargs -= 1;
+ break;
+ case 0:
+ self = ((PyCFunctionObject*)cyfunc)->m_self;
+ break;
+ default:
+ return NULL;
+ }
+ return ((__Pyx_PyCFunctionFastWithKeywords)(void(*)(void))def->ml_meth)(self, args, nargs, kwnames);
+}
+static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames)
+{
+ __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func;
+ PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml;
+ PyTypeObject *cls = (PyTypeObject *) __Pyx_CyFunction_GetClassObj(cyfunc);
+#if CYTHON_BACKPORT_VECTORCALL
+ Py_ssize_t nargs = (Py_ssize_t)nargsf;
+#else
+ Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
+#endif
+ PyObject *self;
+ switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, NULL)) {
+ case 1:
+ self = args[0];
+ args += 1;
+ nargs -= 1;
+ break;
+ case 0:
+ self = ((PyCFunctionObject*)cyfunc)->m_self;
+ break;
+ default:
+ return NULL;
+ }
+ return ((__Pyx_PyCMethod)(void(*)(void))def->ml_meth)(self, cls, args, (size_t)nargs, kwnames);
+}
+#endif
+#if CYTHON_USE_TYPE_SPECS
+static PyType_Slot __pyx_CyFunctionType_slots[] = {
+ {Py_tp_dealloc, (void *)__Pyx_CyFunction_dealloc},
+ {Py_tp_repr, (void *)__Pyx_CyFunction_repr},
+ {Py_tp_call, (void *)__Pyx_CyFunction_CallAsMethod},
+ {Py_tp_traverse, (void *)__Pyx_CyFunction_traverse},
+ {Py_tp_clear, (void *)__Pyx_CyFunction_clear},
+ {Py_tp_methods, (void *)__pyx_CyFunction_methods},
+ {Py_tp_members, (void *)__pyx_CyFunction_members},
+ {Py_tp_getset, (void *)__pyx_CyFunction_getsets},
+ {Py_tp_descr_get, (void *)__Pyx_PyMethod_New},
+ {0, 0},
+};
+static PyType_Spec __pyx_CyFunctionType_spec = {
+ __PYX_TYPE_MODULE_PREFIX "cython_function_or_method",
+ sizeof(__pyx_CyFunctionObject),
+ 0,
+#ifdef Py_TPFLAGS_METHOD_DESCRIPTOR
+ Py_TPFLAGS_METHOD_DESCRIPTOR |
+#endif
+#if (defined(_Py_TPFLAGS_HAVE_VECTORCALL) && CYTHON_METH_FASTCALL)
+ _Py_TPFLAGS_HAVE_VECTORCALL |
+#endif
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE,
+ __pyx_CyFunctionType_slots
+};
+#else
+static PyTypeObject __pyx_CyFunctionType_type = {
+ PyVarObject_HEAD_INIT(0, 0)
+ __PYX_TYPE_MODULE_PREFIX "cython_function_or_method",
+ sizeof(__pyx_CyFunctionObject),
+ 0,
+ (destructor) __Pyx_CyFunction_dealloc,
+#if !CYTHON_METH_FASTCALL
+ 0,
+#elif CYTHON_BACKPORT_VECTORCALL
+ (printfunc)offsetof(__pyx_CyFunctionObject, func_vectorcall),
+#else
+ offsetof(PyCFunctionObject, vectorcall),
+#endif
+ 0,
+ 0,
+#if PY_MAJOR_VERSION < 3
+ 0,
+#else
+ 0,
+#endif
+ (reprfunc) __Pyx_CyFunction_repr,
+ 0,
+ 0,
+ 0,
+ 0,
+ __Pyx_CyFunction_CallAsMethod,
+ 0,
+ 0,
+ 0,
+ 0,
+#ifdef Py_TPFLAGS_METHOD_DESCRIPTOR
+ Py_TPFLAGS_METHOD_DESCRIPTOR |
+#endif
+#if defined(_Py_TPFLAGS_HAVE_VECTORCALL) && CYTHON_METH_FASTCALL
+ _Py_TPFLAGS_HAVE_VECTORCALL |
+#endif
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE,
+ 0,
+ (traverseproc) __Pyx_CyFunction_traverse,
+ (inquiry) __Pyx_CyFunction_clear,
+ 0,
+#if PY_VERSION_HEX < 0x030500A0
+ offsetof(__pyx_CyFunctionObject, func_weakreflist),
+#else
+ offsetof(PyCFunctionObject, m_weakreflist),
+#endif
+ 0,
+ 0,
+ __pyx_CyFunction_methods,
+ __pyx_CyFunction_members,
+ __pyx_CyFunction_getsets,
+ 0,
+ 0,
+ __Pyx_PyMethod_New,
+ 0,
+ offsetof(__pyx_CyFunctionObject, func_dict),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+#if PY_VERSION_HEX >= 0x030400a1
+ 0,
+#endif
+#if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800)
+ 0,
+#endif
+#if __PYX_NEED_TP_PRINT_SLOT
+ 0,
+#endif
+#if PY_VERSION_HEX >= 0x030C0000
+ 0,
+#endif
+#if PY_VERSION_HEX >= 0x030d00A4
+ 0,
+#endif
+#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000
+ 0,
+#endif
+};
+#endif
+static int __pyx_CyFunction_init(PyObject *module) {
+#if CYTHON_USE_TYPE_SPECS
+ __pyx_CyFunctionType = __Pyx_FetchCommonTypeFromSpec(module, &__pyx_CyFunctionType_spec, NULL);
+#else
+ CYTHON_UNUSED_VAR(module);
+ __pyx_CyFunctionType = __Pyx_FetchCommonType(&__pyx_CyFunctionType_type);
+#endif
+ if (unlikely(__pyx_CyFunctionType == NULL)) {
+ return -1;
+ }
+ return 0;
+}
+static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *func, size_t size, int pyobjects) {
+ __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func;
+ m->defaults = PyObject_Malloc(size);
+ if (unlikely(!m->defaults))
+ return PyErr_NoMemory();
+ memset(m->defaults, 0, size);
+ m->defaults_pyobjects = pyobjects;
+ m->defaults_size = size;
+ return m->defaults;
+}
+static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *func, PyObject *tuple) {
+ __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func;
+ m->defaults_tuple = tuple;
+ Py_INCREF(tuple);
+}
+static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *func, PyObject *dict) {
+ __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func;
+ m->defaults_kwdict = dict;
+ Py_INCREF(dict);
+}
+static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *func, PyObject *dict) {
+ __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func;
+ m->func_annotations = dict;
+ Py_INCREF(dict);
+}
+
+/* CythonFunction */
+ static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml, int flags, PyObject* qualname,
+ PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) {
+ PyObject *op = __Pyx_CyFunction_Init(
+ PyObject_GC_New(__pyx_CyFunctionObject, __pyx_CyFunctionType),
+ ml, flags, qualname, closure, module, globals, code
+ );
+ if (likely(op)) {
+ PyObject_GC_Track(op);
+ }
+ return op;
+}
+
+/* CLineInTraceback */
+ #ifndef CYTHON_CLINE_IN_TRACEBACK
+static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line) {
+ PyObject *use_cline;
+ PyObject *ptype, *pvalue, *ptraceback;
+#if CYTHON_COMPILING_IN_CPYTHON
+ PyObject **cython_runtime_dict;
+#endif
+ CYTHON_MAYBE_UNUSED_VAR(tstate);
+ if (unlikely(!__pyx_cython_runtime)) {
+ return c_line;
+ }
+ __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback);
+#if CYTHON_COMPILING_IN_CPYTHON
+ cython_runtime_dict = _PyObject_GetDictPtr(__pyx_cython_runtime);
+ if (likely(cython_runtime_dict)) {
+ __PYX_PY_DICT_LOOKUP_IF_MODIFIED(
+ use_cline, *cython_runtime_dict,
+ __Pyx_PyDict_GetItemStr(*cython_runtime_dict, __pyx_n_s_cline_in_traceback))
+ } else
+#endif
+ {
+ PyObject *use_cline_obj = __Pyx_PyObject_GetAttrStrNoError(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback);
+ if (use_cline_obj) {
+ use_cline = PyObject_Not(use_cline_obj) ? Py_False : Py_True;
+ Py_DECREF(use_cline_obj);
+ } else {
+ PyErr_Clear();
+ use_cline = NULL;
+ }
+ }
+ if (!use_cline) {
+ c_line = 0;
+ (void) PyObject_SetAttr(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback, Py_False);
+ }
+ else if (use_cline == Py_False || (use_cline != Py_True && PyObject_Not(use_cline) != 0)) {
+ c_line = 0;
+ }
+ __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback);
+ return c_line;
+}
+#endif
+
+/* CodeObjectCache */
+ #if !CYTHON_COMPILING_IN_LIMITED_API
+static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line) {
+ int start = 0, mid = 0, end = count - 1;
+ if (end >= 0 && code_line > entries[end].code_line) {
+ return count;
+ }
+ while (start < end) {
+ mid = start + (end - start) / 2;
+ if (code_line < entries[mid].code_line) {
+ end = mid;
+ } else if (code_line > entries[mid].code_line) {
+ start = mid + 1;
+ } else {
+ return mid;
+ }
+ }
+ if (code_line <= entries[mid].code_line) {
+ return mid;
+ } else {
+ return mid + 1;
+ }
+}
+static PyCodeObject *__pyx_find_code_object(int code_line) {
+ PyCodeObject* code_object;
+ int pos;
+ if (unlikely(!code_line) || unlikely(!__pyx_code_cache.entries)) {
+ return NULL;
+ }
+ pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line);
+ if (unlikely(pos >= __pyx_code_cache.count) || unlikely(__pyx_code_cache.entries[pos].code_line != code_line)) {
+ return NULL;
+ }
+ code_object = __pyx_code_cache.entries[pos].code_object;
+ Py_INCREF(code_object);
+ return code_object;
+}
+static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object) {
+ int pos, i;
+ __Pyx_CodeObjectCacheEntry* entries = __pyx_code_cache.entries;
+ if (unlikely(!code_line)) {
+ return;
+ }
+ if (unlikely(!entries)) {
+ entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Malloc(64*sizeof(__Pyx_CodeObjectCacheEntry));
+ if (likely(entries)) {
+ __pyx_code_cache.entries = entries;
+ __pyx_code_cache.max_count = 64;
+ __pyx_code_cache.count = 1;
+ entries[0].code_line = code_line;
+ entries[0].code_object = code_object;
+ Py_INCREF(code_object);
+ }
+ return;
+ }
+ pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line);
+ if ((pos < __pyx_code_cache.count) && unlikely(__pyx_code_cache.entries[pos].code_line == code_line)) {
+ PyCodeObject* tmp = entries[pos].code_object;
+ entries[pos].code_object = code_object;
+ Py_DECREF(tmp);
+ return;
+ }
+ if (__pyx_code_cache.count == __pyx_code_cache.max_count) {
+ int new_max = __pyx_code_cache.max_count + 64;
+ entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Realloc(
+ __pyx_code_cache.entries, ((size_t)new_max) * sizeof(__Pyx_CodeObjectCacheEntry));
+ if (unlikely(!entries)) {
+ return;
+ }
+ __pyx_code_cache.entries = entries;
+ __pyx_code_cache.max_count = new_max;
+ }
+ for (i=__pyx_code_cache.count; i>pos; i--) {
+ entries[i] = entries[i-1];
+ }
+ entries[pos].code_line = code_line;
+ entries[pos].code_object = code_object;
+ __pyx_code_cache.count++;
+ Py_INCREF(code_object);
+}
+#endif
+
+/* AddTraceback */
+ #include "compile.h"
+#include "frameobject.h"
+#include "traceback.h"
+#if PY_VERSION_HEX >= 0x030b00a6 && !CYTHON_COMPILING_IN_LIMITED_API && !defined(PYPY_VERSION)
+ #ifndef Py_BUILD_CORE
+ #define Py_BUILD_CORE 1
+ #endif
+ #include "internal/pycore_frame.h"
+#endif
+#if CYTHON_COMPILING_IN_LIMITED_API
+static PyObject *__Pyx_PyCode_Replace_For_AddTraceback(PyObject *code, PyObject *scratch_dict,
+ PyObject *firstlineno, PyObject *name) {
+ PyObject *replace = NULL;
+ if (unlikely(PyDict_SetItemString(scratch_dict, "co_firstlineno", firstlineno))) return NULL;
+ if (unlikely(PyDict_SetItemString(scratch_dict, "co_name", name))) return NULL;
+ replace = PyObject_GetAttrString(code, "replace");
+ if (likely(replace)) {
+ PyObject *result;
+ result = PyObject_Call(replace, __pyx_empty_tuple, scratch_dict);
+ Py_DECREF(replace);
+ return result;
+ }
+ PyErr_Clear();
+ #if __PYX_LIMITED_VERSION_HEX < 0x030780000
+ {
+ PyObject *compiled = NULL, *result = NULL;
+ if (unlikely(PyDict_SetItemString(scratch_dict, "code", code))) return NULL;
+ if (unlikely(PyDict_SetItemString(scratch_dict, "type", (PyObject*)(&PyType_Type)))) return NULL;
+ compiled = Py_CompileString(
+ "out = type(code)(\n"
+ " code.co_argcount, code.co_kwonlyargcount, code.co_nlocals, code.co_stacksize,\n"
+ " code.co_flags, code.co_code, code.co_consts, code.co_names,\n"
+ " code.co_varnames, code.co_filename, co_name, co_firstlineno,\n"
+ " code.co_lnotab)\n", "", Py_file_input);
+ if (!compiled) return NULL;
+ result = PyEval_EvalCode(compiled, scratch_dict, scratch_dict);
+ Py_DECREF(compiled);
+ if (!result) PyErr_Print();
+ Py_DECREF(result);
+ result = PyDict_GetItemString(scratch_dict, "out");
+ if (result) Py_INCREF(result);
+ return result;
+ }
+ #else
+ return NULL;
+ #endif
+}
+static void __Pyx_AddTraceback(const char *funcname, int c_line,
+ int py_line, const char *filename) {
+ PyObject *code_object = NULL, *py_py_line = NULL, *py_funcname = NULL, *dict = NULL;
+ PyObject *replace = NULL, *getframe = NULL, *frame = NULL;
+ PyObject *exc_type, *exc_value, *exc_traceback;
+ int success = 0;
+ if (c_line) {
+ (void) __pyx_cfilenm;
+ (void) __Pyx_CLineForTraceback(__Pyx_PyThreadState_Current, c_line);
+ }
+ PyErr_Fetch(&exc_type, &exc_value, &exc_traceback);
+ code_object = Py_CompileString("_getframe()", filename, Py_eval_input);
+ if (unlikely(!code_object)) goto bad;
+ py_py_line = PyLong_FromLong(py_line);
+ if (unlikely(!py_py_line)) goto bad;
+ py_funcname = PyUnicode_FromString(funcname);
+ if (unlikely(!py_funcname)) goto bad;
+ dict = PyDict_New();
+ if (unlikely(!dict)) goto bad;
+ {
+ PyObject *old_code_object = code_object;
+ code_object = __Pyx_PyCode_Replace_For_AddTraceback(code_object, dict, py_py_line, py_funcname);
+ Py_DECREF(old_code_object);
+ }
+ if (unlikely(!code_object)) goto bad;
+ getframe = PySys_GetObject("_getframe");
+ if (unlikely(!getframe)) goto bad;
+ if (unlikely(PyDict_SetItemString(dict, "_getframe", getframe))) goto bad;
+ frame = PyEval_EvalCode(code_object, dict, dict);
+ if (unlikely(!frame) || frame == Py_None) goto bad;
+ success = 1;
+ bad:
+ PyErr_Restore(exc_type, exc_value, exc_traceback);
+ Py_XDECREF(code_object);
+ Py_XDECREF(py_py_line);
+ Py_XDECREF(py_funcname);
+ Py_XDECREF(dict);
+ Py_XDECREF(replace);
+ if (success) {
+ PyTraceBack_Here(
+ (struct _frame*)frame);
+ }
+ Py_XDECREF(frame);
+}
+#else
+static PyCodeObject* __Pyx_CreateCodeObjectForTraceback(
+ const char *funcname, int c_line,
+ int py_line, const char *filename) {
+ PyCodeObject *py_code = NULL;
+ PyObject *py_funcname = NULL;
+ #if PY_MAJOR_VERSION < 3
+ PyObject *py_srcfile = NULL;
+ py_srcfile = PyString_FromString(filename);
+ if (!py_srcfile) goto bad;
+ #endif
+ if (c_line) {
+ #if PY_MAJOR_VERSION < 3
+ py_funcname = PyString_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line);
+ if (!py_funcname) goto bad;
+ #else
+ py_funcname = PyUnicode_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line);
+ if (!py_funcname) goto bad;
+ funcname = PyUnicode_AsUTF8(py_funcname);
+ if (!funcname) goto bad;
+ #endif
+ }
+ else {
+ #if PY_MAJOR_VERSION < 3
+ py_funcname = PyString_FromString(funcname);
+ if (!py_funcname) goto bad;
+ #endif
+ }
+ #if PY_MAJOR_VERSION < 3
+ py_code = __Pyx_PyCode_New(
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ __pyx_empty_bytes, /*PyObject *code,*/
+ __pyx_empty_tuple, /*PyObject *consts,*/
+ __pyx_empty_tuple, /*PyObject *names,*/
+ __pyx_empty_tuple, /*PyObject *varnames,*/
+ __pyx_empty_tuple, /*PyObject *freevars,*/
+ __pyx_empty_tuple, /*PyObject *cellvars,*/
+ py_srcfile, /*PyObject *filename,*/
+ py_funcname, /*PyObject *name,*/
+ py_line,
+ __pyx_empty_bytes /*PyObject *lnotab*/
+ );
+ Py_DECREF(py_srcfile);
+ #else
+ py_code = PyCode_NewEmpty(filename, funcname, py_line);
+ #endif
+ Py_XDECREF(py_funcname);
+ return py_code;
+bad:
+ Py_XDECREF(py_funcname);
+ #if PY_MAJOR_VERSION < 3
+ Py_XDECREF(py_srcfile);
+ #endif
+ return NULL;
+}
+static void __Pyx_AddTraceback(const char *funcname, int c_line,
+ int py_line, const char *filename) {
+ PyCodeObject *py_code = 0;
+ PyFrameObject *py_frame = 0;
+ PyThreadState *tstate = __Pyx_PyThreadState_Current;
+ PyObject *ptype, *pvalue, *ptraceback;
+ if (c_line) {
+ c_line = __Pyx_CLineForTraceback(tstate, c_line);
+ }
+ py_code = __pyx_find_code_object(c_line ? -c_line : py_line);
+ if (!py_code) {
+ __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback);
+ py_code = __Pyx_CreateCodeObjectForTraceback(
+ funcname, c_line, py_line, filename);
+ if (!py_code) {
+ /* If the code object creation fails, then we should clear the
+ fetched exception references and propagate the new exception */
+ Py_XDECREF(ptype);
+ Py_XDECREF(pvalue);
+ Py_XDECREF(ptraceback);
+ goto bad;
+ }
+ __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback);
+ __pyx_insert_code_object(c_line ? -c_line : py_line, py_code);
+ }
+ py_frame = PyFrame_New(
+ tstate, /*PyThreadState *tstate,*/
+ py_code, /*PyCodeObject *code,*/
+ __pyx_d, /*PyObject *globals,*/
+ 0 /*PyObject *locals*/
+ );
+ if (!py_frame) goto bad;
+ __Pyx_PyFrame_SetLineNumber(py_frame, py_line);
+ PyTraceBack_Here(py_frame);
+bad:
+ Py_XDECREF(py_code);
+ Py_XDECREF(py_frame);
+}
+#endif
+
+#if PY_MAJOR_VERSION < 3
+static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
+ __Pyx_TypeName obj_type_name;
+ if (PyObject_CheckBuffer(obj)) return PyObject_GetBuffer(obj, view, flags);
+ obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj));
+ PyErr_Format(PyExc_TypeError,
+ "'" __Pyx_FMT_TYPENAME "' does not have the buffer interface",
+ obj_type_name);
+ __Pyx_DECREF_TypeName(obj_type_name);
+ return -1;
+}
+static void __Pyx_ReleaseBuffer(Py_buffer *view) {
+ PyObject *obj = view->obj;
+ if (!obj) return;
+ if (PyObject_CheckBuffer(obj)) {
+ PyBuffer_Release(view);
+ return;
+ }
+ if ((0)) {}
+ view->obj = NULL;
+ Py_DECREF(obj);
+}
+#endif
+
+
+ /* CIntFromPyVerify */
+ #define __PYX_VERIFY_RETURN_INT(target_type, func_type, func_value)\
+ __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 0)
+#define __PYX_VERIFY_RETURN_INT_EXC(target_type, func_type, func_value)\
+ __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 1)
+#define __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, exc)\
+ {\
+ func_type value = func_value;\
+ if (sizeof(target_type) < sizeof(func_type)) {\
+ if (unlikely(value != (func_type) (target_type) value)) {\
+ func_type zero = 0;\
+ if (exc && unlikely(value == (func_type)-1 && PyErr_Occurred()))\
+ return (target_type) -1;\
+ if (is_unsigned && unlikely(value < zero))\
+ goto raise_neg_overflow;\
+ else\
+ goto raise_overflow;\
+ }\
+ }\
+ return (target_type) value;\
+ }
+
+/* Declarations */
+ #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus)
+ #ifdef __cplusplus
+ static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) {
+ return ::std::complex< float >(x, y);
+ }
+ #else
+ static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) {
+ return x + y*(__pyx_t_float_complex)_Complex_I;
+ }
+ #endif
+#else
+ static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) {
+ __pyx_t_float_complex z;
+ z.real = x;
+ z.imag = y;
+ return z;
+ }
+#endif
+
+/* Arithmetic */
+ #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus)
+#else
+ static CYTHON_INLINE int __Pyx_c_eq_float(__pyx_t_float_complex a, __pyx_t_float_complex b) {
+ return (a.real == b.real) && (a.imag == b.imag);
+ }
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_sum_float(__pyx_t_float_complex a, __pyx_t_float_complex b) {
+ __pyx_t_float_complex z;
+ z.real = a.real + b.real;
+ z.imag = a.imag + b.imag;
+ return z;
+ }
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_diff_float(__pyx_t_float_complex a, __pyx_t_float_complex b) {
+ __pyx_t_float_complex z;
+ z.real = a.real - b.real;
+ z.imag = a.imag - b.imag;
+ return z;
+ }
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_prod_float(__pyx_t_float_complex a, __pyx_t_float_complex b) {
+ __pyx_t_float_complex z;
+ z.real = a.real * b.real - a.imag * b.imag;
+ z.imag = a.real * b.imag + a.imag * b.real;
+ return z;
+ }
+ #if 1
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex a, __pyx_t_float_complex b) {
+ if (b.imag == 0) {
+ return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.real);
+ } else if (fabsf(b.real) >= fabsf(b.imag)) {
+ if (b.real == 0 && b.imag == 0) {
+ return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.imag);
+ } else {
+ float r = b.imag / b.real;
+ float s = (float)(1.0) / (b.real + b.imag * r);
+ return __pyx_t_float_complex_from_parts(
+ (a.real + a.imag * r) * s, (a.imag - a.real * r) * s);
+ }
+ } else {
+ float r = b.real / b.imag;
+ float s = (float)(1.0) / (b.imag + b.real * r);
+ return __pyx_t_float_complex_from_parts(
+ (a.real * r + a.imag) * s, (a.imag * r - a.real) * s);
+ }
+ }
+ #else
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex a, __pyx_t_float_complex b) {
+ if (b.imag == 0) {
+ return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.real);
+ } else {
+ float denom = b.real * b.real + b.imag * b.imag;
+ return __pyx_t_float_complex_from_parts(
+ (a.real * b.real + a.imag * b.imag) / denom,
+ (a.imag * b.real - a.real * b.imag) / denom);
+ }
+ }
+ #endif
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_neg_float(__pyx_t_float_complex a) {
+ __pyx_t_float_complex z;
+ z.real = -a.real;
+ z.imag = -a.imag;
+ return z;
+ }
+ static CYTHON_INLINE int __Pyx_c_is_zero_float(__pyx_t_float_complex a) {
+ return (a.real == 0) && (a.imag == 0);
+ }
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_conj_float(__pyx_t_float_complex a) {
+ __pyx_t_float_complex z;
+ z.real = a.real;
+ z.imag = -a.imag;
+ return z;
+ }
+ #if 1
+ static CYTHON_INLINE float __Pyx_c_abs_float(__pyx_t_float_complex z) {
+ #if !defined(HAVE_HYPOT) || defined(_MSC_VER)
+ return sqrtf(z.real*z.real + z.imag*z.imag);
+ #else
+ return hypotf(z.real, z.imag);
+ #endif
+ }
+ static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_pow_float(__pyx_t_float_complex a, __pyx_t_float_complex b) {
+ __pyx_t_float_complex z;
+ float r, lnr, theta, z_r, z_theta;
+ if (b.imag == 0 && b.real == (int)b.real) {
+ if (b.real < 0) {
+ float denom = a.real * a.real + a.imag * a.imag;
+ a.real = a.real / denom;
+ a.imag = -a.imag / denom;
+ b.real = -b.real;
+ }
+ switch ((int)b.real) {
+ case 0:
+ z.real = 1;
+ z.imag = 0;
+ return z;
+ case 1:
+ return a;
+ case 2:
+ return __Pyx_c_prod_float(a, a);
+ case 3:
+ z = __Pyx_c_prod_float(a, a);
+ return __Pyx_c_prod_float(z, a);
+ case 4:
+ z = __Pyx_c_prod_float(a, a);
+ return __Pyx_c_prod_float(z, z);
+ }
+ }
+ if (a.imag == 0) {
+ if (a.real == 0) {
+ return a;
+ } else if ((b.imag == 0) && (a.real >= 0)) {
+ z.real = powf(a.real, b.real);
+ z.imag = 0;
+ return z;
+ } else if (a.real > 0) {
+ r = a.real;
+ theta = 0;
+ } else {
+ r = -a.real;
+ theta = atan2f(0.0, -1.0);
+ }
+ } else {
+ r = __Pyx_c_abs_float(a);
+ theta = atan2f(a.imag, a.real);
+ }
+ lnr = logf(r);
+ z_r = expf(lnr * b.real - theta * b.imag);
+ z_theta = theta * b.real + lnr * b.imag;
+ z.real = z_r * cosf(z_theta);
+ z.imag = z_r * sinf(z_theta);
+ return z;
+ }
+ #endif
+#endif
+
+/* Declarations */
+ #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus)
+ #ifdef __cplusplus
+ static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) {
+ return ::std::complex< double >(x, y);
+ }
+ #else
+ static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) {
+ return x + y*(__pyx_t_double_complex)_Complex_I;
+ }
+ #endif
+#else
+ static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) {
+ __pyx_t_double_complex z;
+ z.real = x;
+ z.imag = y;
+ return z;
+ }
+#endif
+
+/* Arithmetic */
+ #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus)
+#else
+ static CYTHON_INLINE int __Pyx_c_eq_double(__pyx_t_double_complex a, __pyx_t_double_complex b) {
+ return (a.real == b.real) && (a.imag == b.imag);
+ }
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_sum_double(__pyx_t_double_complex a, __pyx_t_double_complex b) {
+ __pyx_t_double_complex z;
+ z.real = a.real + b.real;
+ z.imag = a.imag + b.imag;
+ return z;
+ }
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_diff_double(__pyx_t_double_complex a, __pyx_t_double_complex b) {
+ __pyx_t_double_complex z;
+ z.real = a.real - b.real;
+ z.imag = a.imag - b.imag;
+ return z;
+ }
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_prod_double(__pyx_t_double_complex a, __pyx_t_double_complex b) {
+ __pyx_t_double_complex z;
+ z.real = a.real * b.real - a.imag * b.imag;
+ z.imag = a.real * b.imag + a.imag * b.real;
+ return z;
+ }
+ #if 1
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex a, __pyx_t_double_complex b) {
+ if (b.imag == 0) {
+ return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.real);
+ } else if (fabs(b.real) >= fabs(b.imag)) {
+ if (b.real == 0 && b.imag == 0) {
+ return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.imag);
+ } else {
+ double r = b.imag / b.real;
+ double s = (double)(1.0) / (b.real + b.imag * r);
+ return __pyx_t_double_complex_from_parts(
+ (a.real + a.imag * r) * s, (a.imag - a.real * r) * s);
+ }
+ } else {
+ double r = b.real / b.imag;
+ double s = (double)(1.0) / (b.imag + b.real * r);
+ return __pyx_t_double_complex_from_parts(
+ (a.real * r + a.imag) * s, (a.imag * r - a.real) * s);
+ }
+ }
+ #else
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex a, __pyx_t_double_complex b) {
+ if (b.imag == 0) {
+ return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.real);
+ } else {
+ double denom = b.real * b.real + b.imag * b.imag;
+ return __pyx_t_double_complex_from_parts(
+ (a.real * b.real + a.imag * b.imag) / denom,
+ (a.imag * b.real - a.real * b.imag) / denom);
+ }
+ }
+ #endif
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_neg_double(__pyx_t_double_complex a) {
+ __pyx_t_double_complex z;
+ z.real = -a.real;
+ z.imag = -a.imag;
+ return z;
+ }
+ static CYTHON_INLINE int __Pyx_c_is_zero_double(__pyx_t_double_complex a) {
+ return (a.real == 0) && (a.imag == 0);
+ }
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_conj_double(__pyx_t_double_complex a) {
+ __pyx_t_double_complex z;
+ z.real = a.real;
+ z.imag = -a.imag;
+ return z;
+ }
+ #if 1
+ static CYTHON_INLINE double __Pyx_c_abs_double(__pyx_t_double_complex z) {
+ #if !defined(HAVE_HYPOT) || defined(_MSC_VER)
+ return sqrt(z.real*z.real + z.imag*z.imag);
+ #else
+ return hypot(z.real, z.imag);
+ #endif
+ }
+ static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_pow_double(__pyx_t_double_complex a, __pyx_t_double_complex b) {
+ __pyx_t_double_complex z;
+ double r, lnr, theta, z_r, z_theta;
+ if (b.imag == 0 && b.real == (int)b.real) {
+ if (b.real < 0) {
+ double denom = a.real * a.real + a.imag * a.imag;
+ a.real = a.real / denom;
+ a.imag = -a.imag / denom;
+ b.real = -b.real;
+ }
+ switch ((int)b.real) {
+ case 0:
+ z.real = 1;
+ z.imag = 0;
+ return z;
+ case 1:
+ return a;
+ case 2:
+ return __Pyx_c_prod_double(a, a);
+ case 3:
+ z = __Pyx_c_prod_double(a, a);
+ return __Pyx_c_prod_double(z, a);
+ case 4:
+ z = __Pyx_c_prod_double(a, a);
+ return __Pyx_c_prod_double(z, z);
+ }
+ }
+ if (a.imag == 0) {
+ if (a.real == 0) {
+ return a;
+ } else if ((b.imag == 0) && (a.real >= 0)) {
+ z.real = pow(a.real, b.real);
+ z.imag = 0;
+ return z;
+ } else if (a.real > 0) {
+ r = a.real;
+ theta = 0;
+ } else {
+ r = -a.real;
+ theta = atan2(0.0, -1.0);
+ }
+ } else {
+ r = __Pyx_c_abs_double(a);
+ theta = atan2(a.imag, a.real);
+ }
+ lnr = log(r);
+ z_r = exp(lnr * b.real - theta * b.imag);
+ z_theta = theta * b.real + lnr * b.imag;
+ z.real = z_r * cos(z_theta);
+ z.imag = z_r * sin(z_theta);
+ return z;
+ }
+ #endif
+#endif
+
+/* CIntFromPy */
+ static CYTHON_INLINE unsigned int __Pyx_PyInt_As_unsigned_int(PyObject *x) {
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wconversion"
+#endif
+ const unsigned int neg_one = (unsigned int) -1, const_zero = (unsigned int) 0;
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic pop
+#endif
+ const int is_unsigned = neg_one > const_zero;
+#if PY_MAJOR_VERSION < 3
+ if (likely(PyInt_Check(x))) {
+ if ((sizeof(unsigned int) < sizeof(long))) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, long, PyInt_AS_LONG(x))
+ } else {
+ long val = PyInt_AS_LONG(x);
+ if (is_unsigned && unlikely(val < 0)) {
+ goto raise_neg_overflow;
+ }
+ return (unsigned int) val;
+ }
+ }
+#endif
+ if (unlikely(!PyLong_Check(x))) {
+ unsigned int val;
+ PyObject *tmp = __Pyx_PyNumber_IntOrLong(x);
+ if (!tmp) return (unsigned int) -1;
+ val = __Pyx_PyInt_As_unsigned_int(tmp);
+ Py_DECREF(tmp);
+ return val;
+ }
+ if (is_unsigned) {
+#if CYTHON_USE_PYLONG_INTERNALS
+ if (unlikely(__Pyx_PyLong_IsNeg(x))) {
+ goto raise_neg_overflow;
+ } else if (__Pyx_PyLong_IsCompact(x)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x))
+ } else {
+ const digit* digits = __Pyx_PyLong_Digits(x);
+ assert(__Pyx_PyLong_DigitCount(x) > 1);
+ switch (__Pyx_PyLong_DigitCount(x)) {
+ case 2:
+ if ((8 * sizeof(unsigned int) > 1 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(unsigned int) >= 2 * PyLong_SHIFT)) {
+ return (unsigned int) (((((unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]));
+ }
+ }
+ break;
+ case 3:
+ if ((8 * sizeof(unsigned int) > 2 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(unsigned int) >= 3 * PyLong_SHIFT)) {
+ return (unsigned int) (((((((unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]));
+ }
+ }
+ break;
+ case 4:
+ if ((8 * sizeof(unsigned int) > 3 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(unsigned int) >= 4 * PyLong_SHIFT)) {
+ return (unsigned int) (((((((((unsigned int)digits[3]) << PyLong_SHIFT) | (unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]));
+ }
+ }
+ break;
+ }
+ }
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7
+ if (unlikely(Py_SIZE(x) < 0)) {
+ goto raise_neg_overflow;
+ }
+#else
+ {
+ int result = PyObject_RichCompareBool(x, Py_False, Py_LT);
+ if (unlikely(result < 0))
+ return (unsigned int) -1;
+ if (unlikely(result == 1))
+ goto raise_neg_overflow;
+ }
+#endif
+ if ((sizeof(unsigned int) <= sizeof(unsigned long))) {
+ __PYX_VERIFY_RETURN_INT_EXC(unsigned int, unsigned long, PyLong_AsUnsignedLong(x))
+#ifdef HAVE_LONG_LONG
+ } else if ((sizeof(unsigned int) <= sizeof(unsigned PY_LONG_LONG))) {
+ __PYX_VERIFY_RETURN_INT_EXC(unsigned int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x))
+#endif
+ }
+ } else {
+#if CYTHON_USE_PYLONG_INTERNALS
+ if (__Pyx_PyLong_IsCompact(x)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x))
+ } else {
+ const digit* digits = __Pyx_PyLong_Digits(x);
+ assert(__Pyx_PyLong_DigitCount(x) > 1);
+ switch (__Pyx_PyLong_SignedDigitCount(x)) {
+ case -2:
+ if ((8 * sizeof(unsigned int) - 1 > 1 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(unsigned int) - 1 > 2 * PyLong_SHIFT)) {
+ return (unsigned int) (((unsigned int)-1)*(((((unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])));
+ }
+ }
+ break;
+ case 2:
+ if ((8 * sizeof(unsigned int) > 1 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(unsigned int) - 1 > 2 * PyLong_SHIFT)) {
+ return (unsigned int) ((((((unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])));
+ }
+ }
+ break;
+ case -3:
+ if ((8 * sizeof(unsigned int) - 1 > 2 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(unsigned int) - 1 > 3 * PyLong_SHIFT)) {
+ return (unsigned int) (((unsigned int)-1)*(((((((unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])));
+ }
+ }
+ break;
+ case 3:
+ if ((8 * sizeof(unsigned int) > 2 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(unsigned int) - 1 > 3 * PyLong_SHIFT)) {
+ return (unsigned int) ((((((((unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])));
+ }
+ }
+ break;
+ case -4:
+ if ((8 * sizeof(unsigned int) - 1 > 3 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(unsigned int) - 1 > 4 * PyLong_SHIFT)) {
+ return (unsigned int) (((unsigned int)-1)*(((((((((unsigned int)digits[3]) << PyLong_SHIFT) | (unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])));
+ }
+ }
+ break;
+ case 4:
+ if ((8 * sizeof(unsigned int) > 3 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(unsigned int) - 1 > 4 * PyLong_SHIFT)) {
+ return (unsigned int) ((((((((((unsigned int)digits[3]) << PyLong_SHIFT) | (unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])));
+ }
+ }
+ break;
+ }
+ }
+#endif
+ if ((sizeof(unsigned int) <= sizeof(long))) {
+ __PYX_VERIFY_RETURN_INT_EXC(unsigned int, long, PyLong_AsLong(x))
+#ifdef HAVE_LONG_LONG
+ } else if ((sizeof(unsigned int) <= sizeof(PY_LONG_LONG))) {
+ __PYX_VERIFY_RETURN_INT_EXC(unsigned int, PY_LONG_LONG, PyLong_AsLongLong(x))
+#endif
+ }
+ }
+ {
+ unsigned int val;
+ int ret = -1;
+#if PY_VERSION_HEX >= 0x030d00A6 && !CYTHON_COMPILING_IN_LIMITED_API
+ Py_ssize_t bytes_copied = PyLong_AsNativeBytes(
+ x, &val, sizeof(val), Py_ASNATIVEBYTES_NATIVE_ENDIAN | (is_unsigned ? Py_ASNATIVEBYTES_UNSIGNED_BUFFER | Py_ASNATIVEBYTES_REJECT_NEGATIVE : 0));
+ if (unlikely(bytes_copied == -1)) {
+ } else if (unlikely(bytes_copied > (Py_ssize_t) sizeof(val))) {
+ goto raise_overflow;
+ } else {
+ ret = 0;
+ }
+#elif PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray)
+ int one = 1; int is_little = (int)*(unsigned char *)&one;
+ unsigned char *bytes = (unsigned char *)&val;
+ ret = _PyLong_AsByteArray((PyLongObject *)x,
+ bytes, sizeof(val),
+ is_little, !is_unsigned);
+#else
+ PyObject *v;
+ PyObject *stepval = NULL, *mask = NULL, *shift = NULL;
+ int bits, remaining_bits, is_negative = 0;
+ int chunk_size = (sizeof(long) < 8) ? 30 : 62;
+ if (likely(PyLong_CheckExact(x))) {
+ v = __Pyx_NewRef(x);
+ } else {
+ v = PyNumber_Long(x);
+ if (unlikely(!v)) return (unsigned int) -1;
+ assert(PyLong_CheckExact(v));
+ }
+ {
+ int result = PyObject_RichCompareBool(v, Py_False, Py_LT);
+ if (unlikely(result < 0)) {
+ Py_DECREF(v);
+ return (unsigned int) -1;
+ }
+ is_negative = result == 1;
+ }
+ if (is_unsigned && unlikely(is_negative)) {
+ Py_DECREF(v);
+ goto raise_neg_overflow;
+ } else if (is_negative) {
+ stepval = PyNumber_Invert(v);
+ Py_DECREF(v);
+ if (unlikely(!stepval))
+ return (unsigned int) -1;
+ } else {
+ stepval = v;
+ }
+ v = NULL;
+ val = (unsigned int) 0;
+ mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done;
+ shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done;
+ for (bits = 0; bits < (int) sizeof(unsigned int) * 8 - chunk_size; bits += chunk_size) {
+ PyObject *tmp, *digit;
+ long idigit;
+ digit = PyNumber_And(stepval, mask);
+ if (unlikely(!digit)) goto done;
+ idigit = PyLong_AsLong(digit);
+ Py_DECREF(digit);
+ if (unlikely(idigit < 0)) goto done;
+ val |= ((unsigned int) idigit) << bits;
+ tmp = PyNumber_Rshift(stepval, shift);
+ if (unlikely(!tmp)) goto done;
+ Py_DECREF(stepval); stepval = tmp;
+ }
+ Py_DECREF(shift); shift = NULL;
+ Py_DECREF(mask); mask = NULL;
+ {
+ long idigit = PyLong_AsLong(stepval);
+ if (unlikely(idigit < 0)) goto done;
+ remaining_bits = ((int) sizeof(unsigned int) * 8) - bits - (is_unsigned ? 0 : 1);
+ if (unlikely(idigit >= (1L << remaining_bits)))
+ goto raise_overflow;
+ val |= ((unsigned int) idigit) << bits;
+ }
+ if (!is_unsigned) {
+ if (unlikely(val & (((unsigned int) 1) << (sizeof(unsigned int) * 8 - 1))))
+ goto raise_overflow;
+ if (is_negative)
+ val = ~val;
+ }
+ ret = 0;
+ done:
+ Py_XDECREF(shift);
+ Py_XDECREF(mask);
+ Py_XDECREF(stepval);
+#endif
+ if (unlikely(ret))
+ return (unsigned int) -1;
+ return val;
+ }
+raise_overflow:
+ PyErr_SetString(PyExc_OverflowError,
+ "value too large to convert to unsigned int");
+ return (unsigned int) -1;
+raise_neg_overflow:
+ PyErr_SetString(PyExc_OverflowError,
+ "can't convert negative value to unsigned int");
+ return (unsigned int) -1;
+}
+
+/* CIntToPy */
+ static CYTHON_INLINE PyObject* __Pyx_PyInt_From_unsigned_int(unsigned int value) {
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wconversion"
+#endif
+ const unsigned int neg_one = (unsigned int) -1, const_zero = (unsigned int) 0;
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic pop
+#endif
+ const int is_unsigned = neg_one > const_zero;
+ if (is_unsigned) {
+ if (sizeof(unsigned int) < sizeof(long)) {
+ return PyInt_FromLong((long) value);
+ } else if (sizeof(unsigned int) <= sizeof(unsigned long)) {
+ return PyLong_FromUnsignedLong((unsigned long) value);
+#ifdef HAVE_LONG_LONG
+ } else if (sizeof(unsigned int) <= sizeof(unsigned PY_LONG_LONG)) {
+ return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value);
+#endif
+ }
+ } else {
+ if (sizeof(unsigned int) <= sizeof(long)) {
+ return PyInt_FromLong((long) value);
+#ifdef HAVE_LONG_LONG
+ } else if (sizeof(unsigned int) <= sizeof(PY_LONG_LONG)) {
+ return PyLong_FromLongLong((PY_LONG_LONG) value);
+#endif
+ }
+ }
+ {
+ unsigned char *bytes = (unsigned char *)&value;
+#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX >= 0x030d00A4
+ if (is_unsigned) {
+ return PyLong_FromUnsignedNativeBytes(bytes, sizeof(value), -1);
+ } else {
+ return PyLong_FromNativeBytes(bytes, sizeof(value), -1);
+ }
+#elif !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000
+ int one = 1; int little = (int)*(unsigned char *)&one;
+ return _PyLong_FromByteArray(bytes, sizeof(unsigned int),
+ little, !is_unsigned);
+#else
+ int one = 1; int little = (int)*(unsigned char *)&one;
+ PyObject *from_bytes, *result = NULL;
+ PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL;
+ from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes");
+ if (!from_bytes) return NULL;
+ py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(unsigned int));
+ if (!py_bytes) goto limited_bad;
+ order_str = PyUnicode_FromString(little ? "little" : "big");
+ if (!order_str) goto limited_bad;
+ arg_tuple = PyTuple_Pack(2, py_bytes, order_str);
+ if (!arg_tuple) goto limited_bad;
+ if (!is_unsigned) {
+ kwds = PyDict_New();
+ if (!kwds) goto limited_bad;
+ if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad;
+ }
+ result = PyObject_Call(from_bytes, arg_tuple, kwds);
+ limited_bad:
+ Py_XDECREF(kwds);
+ Py_XDECREF(arg_tuple);
+ Py_XDECREF(order_str);
+ Py_XDECREF(py_bytes);
+ Py_XDECREF(from_bytes);
+ return result;
+#endif
+ }
+}
+
+/* CIntToPy */
+ static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value) {
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wconversion"
+#endif
+ const int neg_one = (int) -1, const_zero = (int) 0;
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic pop
+#endif
+ const int is_unsigned = neg_one > const_zero;
+ if (is_unsigned) {
+ if (sizeof(int) < sizeof(long)) {
+ return PyInt_FromLong((long) value);
+ } else if (sizeof(int) <= sizeof(unsigned long)) {
+ return PyLong_FromUnsignedLong((unsigned long) value);
+#ifdef HAVE_LONG_LONG
+ } else if (sizeof(int) <= sizeof(unsigned PY_LONG_LONG)) {
+ return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value);
+#endif
+ }
+ } else {
+ if (sizeof(int) <= sizeof(long)) {
+ return PyInt_FromLong((long) value);
+#ifdef HAVE_LONG_LONG
+ } else if (sizeof(int) <= sizeof(PY_LONG_LONG)) {
+ return PyLong_FromLongLong((PY_LONG_LONG) value);
+#endif
+ }
+ }
+ {
+ unsigned char *bytes = (unsigned char *)&value;
+#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX >= 0x030d00A4
+ if (is_unsigned) {
+ return PyLong_FromUnsignedNativeBytes(bytes, sizeof(value), -1);
+ } else {
+ return PyLong_FromNativeBytes(bytes, sizeof(value), -1);
+ }
+#elif !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000
+ int one = 1; int little = (int)*(unsigned char *)&one;
+ return _PyLong_FromByteArray(bytes, sizeof(int),
+ little, !is_unsigned);
+#else
+ int one = 1; int little = (int)*(unsigned char *)&one;
+ PyObject *from_bytes, *result = NULL;
+ PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL;
+ from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes");
+ if (!from_bytes) return NULL;
+ py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(int));
+ if (!py_bytes) goto limited_bad;
+ order_str = PyUnicode_FromString(little ? "little" : "big");
+ if (!order_str) goto limited_bad;
+ arg_tuple = PyTuple_Pack(2, py_bytes, order_str);
+ if (!arg_tuple) goto limited_bad;
+ if (!is_unsigned) {
+ kwds = PyDict_New();
+ if (!kwds) goto limited_bad;
+ if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad;
+ }
+ result = PyObject_Call(from_bytes, arg_tuple, kwds);
+ limited_bad:
+ Py_XDECREF(kwds);
+ Py_XDECREF(arg_tuple);
+ Py_XDECREF(order_str);
+ Py_XDECREF(py_bytes);
+ Py_XDECREF(from_bytes);
+ return result;
+#endif
+ }
+}
+
+/* CIntFromPy */
+ static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *x) {
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wconversion"
+#endif
+ const int neg_one = (int) -1, const_zero = (int) 0;
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic pop
+#endif
+ const int is_unsigned = neg_one > const_zero;
+#if PY_MAJOR_VERSION < 3
+ if (likely(PyInt_Check(x))) {
+ if ((sizeof(int) < sizeof(long))) {
+ __PYX_VERIFY_RETURN_INT(int, long, PyInt_AS_LONG(x))
+ } else {
+ long val = PyInt_AS_LONG(x);
+ if (is_unsigned && unlikely(val < 0)) {
+ goto raise_neg_overflow;
+ }
+ return (int) val;
+ }
+ }
+#endif
+ if (unlikely(!PyLong_Check(x))) {
+ int val;
+ PyObject *tmp = __Pyx_PyNumber_IntOrLong(x);
+ if (!tmp) return (int) -1;
+ val = __Pyx_PyInt_As_int(tmp);
+ Py_DECREF(tmp);
+ return val;
+ }
+ if (is_unsigned) {
+#if CYTHON_USE_PYLONG_INTERNALS
+ if (unlikely(__Pyx_PyLong_IsNeg(x))) {
+ goto raise_neg_overflow;
+ } else if (__Pyx_PyLong_IsCompact(x)) {
+ __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x))
+ } else {
+ const digit* digits = __Pyx_PyLong_Digits(x);
+ assert(__Pyx_PyLong_DigitCount(x) > 1);
+ switch (__Pyx_PyLong_DigitCount(x)) {
+ case 2:
+ if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(int) >= 2 * PyLong_SHIFT)) {
+ return (int) (((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]));
+ }
+ }
+ break;
+ case 3:
+ if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(int) >= 3 * PyLong_SHIFT)) {
+ return (int) (((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]));
+ }
+ }
+ break;
+ case 4:
+ if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(int) >= 4 * PyLong_SHIFT)) {
+ return (int) (((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]));
+ }
+ }
+ break;
+ }
+ }
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7
+ if (unlikely(Py_SIZE(x) < 0)) {
+ goto raise_neg_overflow;
+ }
+#else
+ {
+ int result = PyObject_RichCompareBool(x, Py_False, Py_LT);
+ if (unlikely(result < 0))
+ return (int) -1;
+ if (unlikely(result == 1))
+ goto raise_neg_overflow;
+ }
+#endif
+ if ((sizeof(int) <= sizeof(unsigned long))) {
+ __PYX_VERIFY_RETURN_INT_EXC(int, unsigned long, PyLong_AsUnsignedLong(x))
+#ifdef HAVE_LONG_LONG
+ } else if ((sizeof(int) <= sizeof(unsigned PY_LONG_LONG))) {
+ __PYX_VERIFY_RETURN_INT_EXC(int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x))
+#endif
+ }
+ } else {
+#if CYTHON_USE_PYLONG_INTERNALS
+ if (__Pyx_PyLong_IsCompact(x)) {
+ __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x))
+ } else {
+ const digit* digits = __Pyx_PyLong_Digits(x);
+ assert(__Pyx_PyLong_DigitCount(x) > 1);
+ switch (__Pyx_PyLong_SignedDigitCount(x)) {
+ case -2:
+ if ((8 * sizeof(int) - 1 > 1 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) {
+ return (int) (((int)-1)*(((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0])));
+ }
+ }
+ break;
+ case 2:
+ if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) {
+ return (int) ((((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0])));
+ }
+ }
+ break;
+ case -3:
+ if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) {
+ return (int) (((int)-1)*(((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])));
+ }
+ }
+ break;
+ case 3:
+ if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) {
+ return (int) ((((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])));
+ }
+ }
+ break;
+ case -4:
+ if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) {
+ return (int) (((int)-1)*(((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])));
+ }
+ }
+ break;
+ case 4:
+ if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) {
+ return (int) ((((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])));
+ }
+ }
+ break;
+ }
+ }
+#endif
+ if ((sizeof(int) <= sizeof(long))) {
+ __PYX_VERIFY_RETURN_INT_EXC(int, long, PyLong_AsLong(x))
+#ifdef HAVE_LONG_LONG
+ } else if ((sizeof(int) <= sizeof(PY_LONG_LONG))) {
+ __PYX_VERIFY_RETURN_INT_EXC(int, PY_LONG_LONG, PyLong_AsLongLong(x))
+#endif
+ }
+ }
+ {
+ int val;
+ int ret = -1;
+#if PY_VERSION_HEX >= 0x030d00A6 && !CYTHON_COMPILING_IN_LIMITED_API
+ Py_ssize_t bytes_copied = PyLong_AsNativeBytes(
+ x, &val, sizeof(val), Py_ASNATIVEBYTES_NATIVE_ENDIAN | (is_unsigned ? Py_ASNATIVEBYTES_UNSIGNED_BUFFER | Py_ASNATIVEBYTES_REJECT_NEGATIVE : 0));
+ if (unlikely(bytes_copied == -1)) {
+ } else if (unlikely(bytes_copied > (Py_ssize_t) sizeof(val))) {
+ goto raise_overflow;
+ } else {
+ ret = 0;
+ }
+#elif PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray)
+ int one = 1; int is_little = (int)*(unsigned char *)&one;
+ unsigned char *bytes = (unsigned char *)&val;
+ ret = _PyLong_AsByteArray((PyLongObject *)x,
+ bytes, sizeof(val),
+ is_little, !is_unsigned);
+#else
+ PyObject *v;
+ PyObject *stepval = NULL, *mask = NULL, *shift = NULL;
+ int bits, remaining_bits, is_negative = 0;
+ int chunk_size = (sizeof(long) < 8) ? 30 : 62;
+ if (likely(PyLong_CheckExact(x))) {
+ v = __Pyx_NewRef(x);
+ } else {
+ v = PyNumber_Long(x);
+ if (unlikely(!v)) return (int) -1;
+ assert(PyLong_CheckExact(v));
+ }
+ {
+ int result = PyObject_RichCompareBool(v, Py_False, Py_LT);
+ if (unlikely(result < 0)) {
+ Py_DECREF(v);
+ return (int) -1;
+ }
+ is_negative = result == 1;
+ }
+ if (is_unsigned && unlikely(is_negative)) {
+ Py_DECREF(v);
+ goto raise_neg_overflow;
+ } else if (is_negative) {
+ stepval = PyNumber_Invert(v);
+ Py_DECREF(v);
+ if (unlikely(!stepval))
+ return (int) -1;
+ } else {
+ stepval = v;
+ }
+ v = NULL;
+ val = (int) 0;
+ mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done;
+ shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done;
+ for (bits = 0; bits < (int) sizeof(int) * 8 - chunk_size; bits += chunk_size) {
+ PyObject *tmp, *digit;
+ long idigit;
+ digit = PyNumber_And(stepval, mask);
+ if (unlikely(!digit)) goto done;
+ idigit = PyLong_AsLong(digit);
+ Py_DECREF(digit);
+ if (unlikely(idigit < 0)) goto done;
+ val |= ((int) idigit) << bits;
+ tmp = PyNumber_Rshift(stepval, shift);
+ if (unlikely(!tmp)) goto done;
+ Py_DECREF(stepval); stepval = tmp;
+ }
+ Py_DECREF(shift); shift = NULL;
+ Py_DECREF(mask); mask = NULL;
+ {
+ long idigit = PyLong_AsLong(stepval);
+ if (unlikely(idigit < 0)) goto done;
+ remaining_bits = ((int) sizeof(int) * 8) - bits - (is_unsigned ? 0 : 1);
+ if (unlikely(idigit >= (1L << remaining_bits)))
+ goto raise_overflow;
+ val |= ((int) idigit) << bits;
+ }
+ if (!is_unsigned) {
+ if (unlikely(val & (((int) 1) << (sizeof(int) * 8 - 1))))
+ goto raise_overflow;
+ if (is_negative)
+ val = ~val;
+ }
+ ret = 0;
+ done:
+ Py_XDECREF(shift);
+ Py_XDECREF(mask);
+ Py_XDECREF(stepval);
+#endif
+ if (unlikely(ret))
+ return (int) -1;
+ return val;
+ }
+raise_overflow:
+ PyErr_SetString(PyExc_OverflowError,
+ "value too large to convert to int");
+ return (int) -1;
+raise_neg_overflow:
+ PyErr_SetString(PyExc_OverflowError,
+ "can't convert negative value to int");
+ return (int) -1;
+}
+
+/* CIntToPy */
+ static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value) {
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wconversion"
+#endif
+ const long neg_one = (long) -1, const_zero = (long) 0;
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic pop
+#endif
+ const int is_unsigned = neg_one > const_zero;
+ if (is_unsigned) {
+ if (sizeof(long) < sizeof(long)) {
+ return PyInt_FromLong((long) value);
+ } else if (sizeof(long) <= sizeof(unsigned long)) {
+ return PyLong_FromUnsignedLong((unsigned long) value);
+#ifdef HAVE_LONG_LONG
+ } else if (sizeof(long) <= sizeof(unsigned PY_LONG_LONG)) {
+ return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value);
+#endif
+ }
+ } else {
+ if (sizeof(long) <= sizeof(long)) {
+ return PyInt_FromLong((long) value);
+#ifdef HAVE_LONG_LONG
+ } else if (sizeof(long) <= sizeof(PY_LONG_LONG)) {
+ return PyLong_FromLongLong((PY_LONG_LONG) value);
+#endif
+ }
+ }
+ {
+ unsigned char *bytes = (unsigned char *)&value;
+#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX >= 0x030d00A4
+ if (is_unsigned) {
+ return PyLong_FromUnsignedNativeBytes(bytes, sizeof(value), -1);
+ } else {
+ return PyLong_FromNativeBytes(bytes, sizeof(value), -1);
+ }
+#elif !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000
+ int one = 1; int little = (int)*(unsigned char *)&one;
+ return _PyLong_FromByteArray(bytes, sizeof(long),
+ little, !is_unsigned);
+#else
+ int one = 1; int little = (int)*(unsigned char *)&one;
+ PyObject *from_bytes, *result = NULL;
+ PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL;
+ from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes");
+ if (!from_bytes) return NULL;
+ py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(long));
+ if (!py_bytes) goto limited_bad;
+ order_str = PyUnicode_FromString(little ? "little" : "big");
+ if (!order_str) goto limited_bad;
+ arg_tuple = PyTuple_Pack(2, py_bytes, order_str);
+ if (!arg_tuple) goto limited_bad;
+ if (!is_unsigned) {
+ kwds = PyDict_New();
+ if (!kwds) goto limited_bad;
+ if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad;
+ }
+ result = PyObject_Call(from_bytes, arg_tuple, kwds);
+ limited_bad:
+ Py_XDECREF(kwds);
+ Py_XDECREF(arg_tuple);
+ Py_XDECREF(order_str);
+ Py_XDECREF(py_bytes);
+ Py_XDECREF(from_bytes);
+ return result;
+#endif
+ }
+}
+
+/* FormatTypeName */
+ #if CYTHON_COMPILING_IN_LIMITED_API
+static __Pyx_TypeName
+__Pyx_PyType_GetName(PyTypeObject* tp)
+{
+ PyObject *name = __Pyx_PyObject_GetAttrStr((PyObject *)tp,
+ __pyx_n_s_name);
+ if (unlikely(name == NULL) || unlikely(!PyUnicode_Check(name))) {
+ PyErr_Clear();
+ Py_XDECREF(name);
+ name = __Pyx_NewRef(__pyx_n_s__15);
+ }
+ return name;
+}
+#endif
+
+/* CIntFromPy */
+ static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *x) {
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wconversion"
+#endif
+ const long neg_one = (long) -1, const_zero = (long) 0;
+#ifdef __Pyx_HAS_GCC_DIAGNOSTIC
+#pragma GCC diagnostic pop
+#endif
+ const int is_unsigned = neg_one > const_zero;
+#if PY_MAJOR_VERSION < 3
+ if (likely(PyInt_Check(x))) {
+ if ((sizeof(long) < sizeof(long))) {
+ __PYX_VERIFY_RETURN_INT(long, long, PyInt_AS_LONG(x))
+ } else {
+ long val = PyInt_AS_LONG(x);
+ if (is_unsigned && unlikely(val < 0)) {
+ goto raise_neg_overflow;
+ }
+ return (long) val;
+ }
+ }
+#endif
+ if (unlikely(!PyLong_Check(x))) {
+ long val;
+ PyObject *tmp = __Pyx_PyNumber_IntOrLong(x);
+ if (!tmp) return (long) -1;
+ val = __Pyx_PyInt_As_long(tmp);
+ Py_DECREF(tmp);
+ return val;
+ }
+ if (is_unsigned) {
+#if CYTHON_USE_PYLONG_INTERNALS
+ if (unlikely(__Pyx_PyLong_IsNeg(x))) {
+ goto raise_neg_overflow;
+ } else if (__Pyx_PyLong_IsCompact(x)) {
+ __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x))
+ } else {
+ const digit* digits = __Pyx_PyLong_Digits(x);
+ assert(__Pyx_PyLong_DigitCount(x) > 1);
+ switch (__Pyx_PyLong_DigitCount(x)) {
+ case 2:
+ if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(long) >= 2 * PyLong_SHIFT)) {
+ return (long) (((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]));
+ }
+ }
+ break;
+ case 3:
+ if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(long) >= 3 * PyLong_SHIFT)) {
+ return (long) (((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]));
+ }
+ }
+ break;
+ case 4:
+ if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(long) >= 4 * PyLong_SHIFT)) {
+ return (long) (((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]));
+ }
+ }
+ break;
+ }
+ }
+#endif
+#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7
+ if (unlikely(Py_SIZE(x) < 0)) {
+ goto raise_neg_overflow;
+ }
+#else
+ {
+ int result = PyObject_RichCompareBool(x, Py_False, Py_LT);
+ if (unlikely(result < 0))
+ return (long) -1;
+ if (unlikely(result == 1))
+ goto raise_neg_overflow;
+ }
+#endif
+ if ((sizeof(long) <= sizeof(unsigned long))) {
+ __PYX_VERIFY_RETURN_INT_EXC(long, unsigned long, PyLong_AsUnsignedLong(x))
+#ifdef HAVE_LONG_LONG
+ } else if ((sizeof(long) <= sizeof(unsigned PY_LONG_LONG))) {
+ __PYX_VERIFY_RETURN_INT_EXC(long, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x))
+#endif
+ }
+ } else {
+#if CYTHON_USE_PYLONG_INTERNALS
+ if (__Pyx_PyLong_IsCompact(x)) {
+ __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x))
+ } else {
+ const digit* digits = __Pyx_PyLong_Digits(x);
+ assert(__Pyx_PyLong_DigitCount(x) > 1);
+ switch (__Pyx_PyLong_SignedDigitCount(x)) {
+ case -2:
+ if ((8 * sizeof(long) - 1 > 1 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) {
+ return (long) (((long)-1)*(((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0])));
+ }
+ }
+ break;
+ case 2:
+ if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) {
+ return (long) ((((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0])));
+ }
+ }
+ break;
+ case -3:
+ if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) {
+ return (long) (((long)-1)*(((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])));
+ }
+ }
+ break;
+ case 3:
+ if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) {
+ return (long) ((((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])));
+ }
+ }
+ break;
+ case -4:
+ if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) {
+ return (long) (((long)-1)*(((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])));
+ }
+ }
+ break;
+ case 4:
+ if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) {
+ if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) {
+ __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])))
+ } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) {
+ return (long) ((((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])));
+ }
+ }
+ break;
+ }
+ }
+#endif
+ if ((sizeof(long) <= sizeof(long))) {
+ __PYX_VERIFY_RETURN_INT_EXC(long, long, PyLong_AsLong(x))
+#ifdef HAVE_LONG_LONG
+ } else if ((sizeof(long) <= sizeof(PY_LONG_LONG))) {
+ __PYX_VERIFY_RETURN_INT_EXC(long, PY_LONG_LONG, PyLong_AsLongLong(x))
+#endif
+ }
+ }
+ {
+ long val;
+ int ret = -1;
+#if PY_VERSION_HEX >= 0x030d00A6 && !CYTHON_COMPILING_IN_LIMITED_API
+ Py_ssize_t bytes_copied = PyLong_AsNativeBytes(
+ x, &val, sizeof(val), Py_ASNATIVEBYTES_NATIVE_ENDIAN | (is_unsigned ? Py_ASNATIVEBYTES_UNSIGNED_BUFFER | Py_ASNATIVEBYTES_REJECT_NEGATIVE : 0));
+ if (unlikely(bytes_copied == -1)) {
+ } else if (unlikely(bytes_copied > (Py_ssize_t) sizeof(val))) {
+ goto raise_overflow;
+ } else {
+ ret = 0;
+ }
+#elif PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray)
+ int one = 1; int is_little = (int)*(unsigned char *)&one;
+ unsigned char *bytes = (unsigned char *)&val;
+ ret = _PyLong_AsByteArray((PyLongObject *)x,
+ bytes, sizeof(val),
+ is_little, !is_unsigned);
+#else
+ PyObject *v;
+ PyObject *stepval = NULL, *mask = NULL, *shift = NULL;
+ int bits, remaining_bits, is_negative = 0;
+ int chunk_size = (sizeof(long) < 8) ? 30 : 62;
+ if (likely(PyLong_CheckExact(x))) {
+ v = __Pyx_NewRef(x);
+ } else {
+ v = PyNumber_Long(x);
+ if (unlikely(!v)) return (long) -1;
+ assert(PyLong_CheckExact(v));
+ }
+ {
+ int result = PyObject_RichCompareBool(v, Py_False, Py_LT);
+ if (unlikely(result < 0)) {
+ Py_DECREF(v);
+ return (long) -1;
+ }
+ is_negative = result == 1;
+ }
+ if (is_unsigned && unlikely(is_negative)) {
+ Py_DECREF(v);
+ goto raise_neg_overflow;
+ } else if (is_negative) {
+ stepval = PyNumber_Invert(v);
+ Py_DECREF(v);
+ if (unlikely(!stepval))
+ return (long) -1;
+ } else {
+ stepval = v;
+ }
+ v = NULL;
+ val = (long) 0;
+ mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done;
+ shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done;
+ for (bits = 0; bits < (int) sizeof(long) * 8 - chunk_size; bits += chunk_size) {
+ PyObject *tmp, *digit;
+ long idigit;
+ digit = PyNumber_And(stepval, mask);
+ if (unlikely(!digit)) goto done;
+ idigit = PyLong_AsLong(digit);
+ Py_DECREF(digit);
+ if (unlikely(idigit < 0)) goto done;
+ val |= ((long) idigit) << bits;
+ tmp = PyNumber_Rshift(stepval, shift);
+ if (unlikely(!tmp)) goto done;
+ Py_DECREF(stepval); stepval = tmp;
+ }
+ Py_DECREF(shift); shift = NULL;
+ Py_DECREF(mask); mask = NULL;
+ {
+ long idigit = PyLong_AsLong(stepval);
+ if (unlikely(idigit < 0)) goto done;
+ remaining_bits = ((int) sizeof(long) * 8) - bits - (is_unsigned ? 0 : 1);
+ if (unlikely(idigit >= (1L << remaining_bits)))
+ goto raise_overflow;
+ val |= ((long) idigit) << bits;
+ }
+ if (!is_unsigned) {
+ if (unlikely(val & (((long) 1) << (sizeof(long) * 8 - 1))))
+ goto raise_overflow;
+ if (is_negative)
+ val = ~val;
+ }
+ ret = 0;
+ done:
+ Py_XDECREF(shift);
+ Py_XDECREF(mask);
+ Py_XDECREF(stepval);
+#endif
+ if (unlikely(ret))
+ return (long) -1;
+ return val;
+ }
+raise_overflow:
+ PyErr_SetString(PyExc_OverflowError,
+ "value too large to convert to long");
+ return (long) -1;
+raise_neg_overflow:
+ PyErr_SetString(PyExc_OverflowError,
+ "can't convert negative value to long");
+ return (long) -1;
+}
+
+/* FastTypeChecks */
+ #if CYTHON_COMPILING_IN_CPYTHON
+static int __Pyx_InBases(PyTypeObject *a, PyTypeObject *b) {
+ while (a) {
+ a = __Pyx_PyType_GetSlot(a, tp_base, PyTypeObject*);
+ if (a == b)
+ return 1;
+ }
+ return b == &PyBaseObject_Type;
+}
+static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b) {
+ PyObject *mro;
+ if (a == b) return 1;
+ mro = a->tp_mro;
+ if (likely(mro)) {
+ Py_ssize_t i, n;
+ n = PyTuple_GET_SIZE(mro);
+ for (i = 0; i < n; i++) {
+ if (PyTuple_GET_ITEM(mro, i) == (PyObject *)b)
+ return 1;
+ }
+ return 0;
+ }
+ return __Pyx_InBases(a, b);
+}
+static CYTHON_INLINE int __Pyx_IsAnySubtype2(PyTypeObject *cls, PyTypeObject *a, PyTypeObject *b) {
+ PyObject *mro;
+ if (cls == a || cls == b) return 1;
+ mro = cls->tp_mro;
+ if (likely(mro)) {
+ Py_ssize_t i, n;
+ n = PyTuple_GET_SIZE(mro);
+ for (i = 0; i < n; i++) {
+ PyObject *base = PyTuple_GET_ITEM(mro, i);
+ if (base == (PyObject *)a || base == (PyObject *)b)
+ return 1;
+ }
+ return 0;
+ }
+ return __Pyx_InBases(cls, a) || __Pyx_InBases(cls, b);
+}
+#if PY_MAJOR_VERSION == 2
+static int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject* exc_type2) {
+ PyObject *exception, *value, *tb;
+ int res;
+ __Pyx_PyThreadState_declare
+ __Pyx_PyThreadState_assign
+ __Pyx_ErrFetch(&exception, &value, &tb);
+ res = exc_type1 ? PyObject_IsSubclass(err, exc_type1) : 0;
+ if (unlikely(res == -1)) {
+ PyErr_WriteUnraisable(err);
+ res = 0;
+ }
+ if (!res) {
+ res = PyObject_IsSubclass(err, exc_type2);
+ if (unlikely(res == -1)) {
+ PyErr_WriteUnraisable(err);
+ res = 0;
+ }
+ }
+ __Pyx_ErrRestore(exception, value, tb);
+ return res;
+}
+#else
+static CYTHON_INLINE int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject *exc_type2) {
+ if (exc_type1) {
+ return __Pyx_IsAnySubtype2((PyTypeObject*)err, (PyTypeObject*)exc_type1, (PyTypeObject*)exc_type2);
+ } else {
+ return __Pyx_IsSubtype((PyTypeObject*)err, (PyTypeObject*)exc_type2);
+ }
+}
+#endif
+static int __Pyx_PyErr_GivenExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) {
+ Py_ssize_t i, n;
+ assert(PyExceptionClass_Check(exc_type));
+ n = PyTuple_GET_SIZE(tuple);
+#if PY_MAJOR_VERSION >= 3
+ for (i=0; i= 0x030B00A4
+ return Py_Version & ~0xFFUL;
+#else
+ const char* rt_version = Py_GetVersion();
+ unsigned long version = 0;
+ unsigned long factor = 0x01000000UL;
+ unsigned int digit = 0;
+ int i = 0;
+ while (factor) {
+ while ('0' <= rt_version[i] && rt_version[i] <= '9') {
+ digit = digit * 10 + (unsigned int) (rt_version[i] - '0');
+ ++i;
+ }
+ version += factor * digit;
+ if (rt_version[i] != '.')
+ break;
+ digit = 0;
+ factor >>= 8;
+ ++i;
+ }
+ return version;
+#endif
+}
+static int __Pyx_check_binary_version(unsigned long ct_version, unsigned long rt_version, int allow_newer) {
+ const unsigned long MAJOR_MINOR = 0xFFFF0000UL;
+ if ((rt_version & MAJOR_MINOR) == (ct_version & MAJOR_MINOR))
+ return 0;
+ if (likely(allow_newer && (rt_version & MAJOR_MINOR) > (ct_version & MAJOR_MINOR)))
+ return 1;
+ {
+ char message[200];
+ PyOS_snprintf(message, sizeof(message),
+ "compile time Python version %d.%d "
+ "of module '%.100s' "
+ "%s "
+ "runtime version %d.%d",
+ (int) (ct_version >> 24), (int) ((ct_version >> 16) & 0xFF),
+ __Pyx_MODULE_NAME,
+ (allow_newer) ? "was newer than" : "does not match",
+ (int) (rt_version >> 24), (int) ((rt_version >> 16) & 0xFF)
+ );
+ return PyErr_WarnEx(NULL, message, 1);
+ }
+}
+
+/* InitStrings */
+ #if PY_MAJOR_VERSION >= 3
+static int __Pyx_InitString(__Pyx_StringTabEntry t, PyObject **str) {
+ if (t.is_unicode | t.is_str) {
+ if (t.intern) {
+ *str = PyUnicode_InternFromString(t.s);
+ } else if (t.encoding) {
+ *str = PyUnicode_Decode(t.s, t.n - 1, t.encoding, NULL);
+ } else {
+ *str = PyUnicode_FromStringAndSize(t.s, t.n - 1);
+ }
+ } else {
+ *str = PyBytes_FromStringAndSize(t.s, t.n - 1);
+ }
+ if (!*str)
+ return -1;
+ if (PyObject_Hash(*str) == -1)
+ return -1;
+ return 0;
+}
+#endif
+static int __Pyx_InitStrings(__Pyx_StringTabEntry *t) {
+ while (t->p) {
+ #if PY_MAJOR_VERSION >= 3
+ __Pyx_InitString(*t, t->p);
+ #else
+ if (t->is_unicode) {
+ *t->p = PyUnicode_DecodeUTF8(t->s, t->n - 1, NULL);
+ } else if (t->intern) {
+ *t->p = PyString_InternFromString(t->s);
+ } else {
+ *t->p = PyString_FromStringAndSize(t->s, t->n - 1);
+ }
+ if (!*t->p)
+ return -1;
+ if (PyObject_Hash(*t->p) == -1)
+ return -1;
+ #endif
+ ++t;
+ }
+ return 0;
+}
+
+#include
+static CYTHON_INLINE Py_ssize_t __Pyx_ssize_strlen(const char *s) {
+ size_t len = strlen(s);
+ if (unlikely(len > (size_t) PY_SSIZE_T_MAX)) {
+ PyErr_SetString(PyExc_OverflowError, "byte string is too long");
+ return -1;
+ }
+ return (Py_ssize_t) len;
+}
+static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char* c_str) {
+ Py_ssize_t len = __Pyx_ssize_strlen(c_str);
+ if (unlikely(len < 0)) return NULL;
+ return __Pyx_PyUnicode_FromStringAndSize(c_str, len);
+}
+static CYTHON_INLINE PyObject* __Pyx_PyByteArray_FromString(const char* c_str) {
+ Py_ssize_t len = __Pyx_ssize_strlen(c_str);
+ if (unlikely(len < 0)) return NULL;
+ return PyByteArray_FromStringAndSize(c_str, len);
+}
+static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject* o) {
+ Py_ssize_t ignore;
+ return __Pyx_PyObject_AsStringAndSize(o, &ignore);
+}
+#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT
+#if !CYTHON_PEP393_ENABLED
+static const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) {
+ char* defenc_c;
+ PyObject* defenc = _PyUnicode_AsDefaultEncodedString(o, NULL);
+ if (!defenc) return NULL;
+ defenc_c = PyBytes_AS_STRING(defenc);
+#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII
+ {
+ char* end = defenc_c + PyBytes_GET_SIZE(defenc);
+ char* c;
+ for (c = defenc_c; c < end; c++) {
+ if ((unsigned char) (*c) >= 128) {
+ PyUnicode_AsASCIIString(o);
+ return NULL;
+ }
+ }
+ }
+#endif
+ *length = PyBytes_GET_SIZE(defenc);
+ return defenc_c;
+}
+#else
+static CYTHON_INLINE const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) {
+ if (unlikely(__Pyx_PyUnicode_READY(o) == -1)) return NULL;
+#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII
+ if (likely(PyUnicode_IS_ASCII(o))) {
+ *length = PyUnicode_GET_LENGTH(o);
+ return PyUnicode_AsUTF8(o);
+ } else {
+ PyUnicode_AsASCIIString(o);
+ return NULL;
+ }
+#else
+ return PyUnicode_AsUTF8AndSize(o, length);
+#endif
+}
+#endif
+#endif
+static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject* o, Py_ssize_t *length) {
+#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT
+ if (
+#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII
+ __Pyx_sys_getdefaultencoding_not_ascii &&
+#endif
+ PyUnicode_Check(o)) {
+ return __Pyx_PyUnicode_AsStringAndSize(o, length);
+ } else
+#endif
+#if (!CYTHON_COMPILING_IN_PYPY && !CYTHON_COMPILING_IN_LIMITED_API) || (defined(PyByteArray_AS_STRING) && defined(PyByteArray_GET_SIZE))
+ if (PyByteArray_Check(o)) {
+ *length = PyByteArray_GET_SIZE(o);
+ return PyByteArray_AS_STRING(o);
+ } else
+#endif
+ {
+ char* result;
+ int r = PyBytes_AsStringAndSize(o, &result, length);
+ if (unlikely(r < 0)) {
+ return NULL;
+ } else {
+ return result;
+ }
+ }
+}
+static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject* x) {
+ int is_true = x == Py_True;
+ if (is_true | (x == Py_False) | (x == Py_None)) return is_true;
+ else return PyObject_IsTrue(x);
+}
+static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject* x) {
+ int retval;
+ if (unlikely(!x)) return -1;
+ retval = __Pyx_PyObject_IsTrue(x);
+ Py_DECREF(x);
+ return retval;
+}
+static PyObject* __Pyx_PyNumber_IntOrLongWrongResultType(PyObject* result, const char* type_name) {
+ __Pyx_TypeName result_type_name = __Pyx_PyType_GetName(Py_TYPE(result));
+#if PY_MAJOR_VERSION >= 3
+ if (PyLong_Check(result)) {
+ if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
+ "__int__ returned non-int (type " __Pyx_FMT_TYPENAME "). "
+ "The ability to return an instance of a strict subclass of int is deprecated, "
+ "and may be removed in a future version of Python.",
+ result_type_name)) {
+ __Pyx_DECREF_TypeName(result_type_name);
+ Py_DECREF(result);
+ return NULL;
+ }
+ __Pyx_DECREF_TypeName(result_type_name);
+ return result;
+ }
+#endif
+ PyErr_Format(PyExc_TypeError,
+ "__%.4s__ returned non-%.4s (type " __Pyx_FMT_TYPENAME ")",
+ type_name, type_name, result_type_name);
+ __Pyx_DECREF_TypeName(result_type_name);
+ Py_DECREF(result);
+ return NULL;
+}
+static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x) {
+#if CYTHON_USE_TYPE_SLOTS
+ PyNumberMethods *m;
+#endif
+ const char *name = NULL;
+ PyObject *res = NULL;
+#if PY_MAJOR_VERSION < 3
+ if (likely(PyInt_Check(x) || PyLong_Check(x)))
+#else
+ if (likely(PyLong_Check(x)))
+#endif
+ return __Pyx_NewRef(x);
+#if CYTHON_USE_TYPE_SLOTS
+ m = Py_TYPE(x)->tp_as_number;
+ #if PY_MAJOR_VERSION < 3
+ if (m && m->nb_int) {
+ name = "int";
+ res = m->nb_int(x);
+ }
+ else if (m && m->nb_long) {
+ name = "long";
+ res = m->nb_long(x);
+ }
+ #else
+ if (likely(m && m->nb_int)) {
+ name = "int";
+ res = m->nb_int(x);
+ }
+ #endif
+#else
+ if (!PyBytes_CheckExact(x) && !PyUnicode_CheckExact(x)) {
+ res = PyNumber_Int(x);
+ }
+#endif
+ if (likely(res)) {
+#if PY_MAJOR_VERSION < 3
+ if (unlikely(!PyInt_Check(res) && !PyLong_Check(res))) {
+#else
+ if (unlikely(!PyLong_CheckExact(res))) {
+#endif
+ return __Pyx_PyNumber_IntOrLongWrongResultType(res, name);
+ }
+ }
+ else if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError,
+ "an integer is required");
+ }
+ return res;
+}
+static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject* b) {
+ Py_ssize_t ival;
+ PyObject *x;
+#if PY_MAJOR_VERSION < 3
+ if (likely(PyInt_CheckExact(b))) {
+ if (sizeof(Py_ssize_t) >= sizeof(long))
+ return PyInt_AS_LONG(b);
+ else
+ return PyInt_AsSsize_t(b);
+ }
+#endif
+ if (likely(PyLong_CheckExact(b))) {
+ #if CYTHON_USE_PYLONG_INTERNALS
+ if (likely(__Pyx_PyLong_IsCompact(b))) {
+ return __Pyx_PyLong_CompactValue(b);
+ } else {
+ const digit* digits = __Pyx_PyLong_Digits(b);
+ const Py_ssize_t size = __Pyx_PyLong_SignedDigitCount(b);
+ switch (size) {
+ case 2:
+ if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) {
+ return (Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]));
+ }
+ break;
+ case -2:
+ if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) {
+ return -(Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]));
+ }
+ break;
+ case 3:
+ if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) {
+ return (Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]));
+ }
+ break;
+ case -3:
+ if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) {
+ return -(Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]));
+ }
+ break;
+ case 4:
+ if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) {
+ return (Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]));
+ }
+ break;
+ case -4:
+ if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) {
+ return -(Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0]));
+ }
+ break;
+ }
+ }
+ #endif
+ return PyLong_AsSsize_t(b);
+ }
+ x = PyNumber_Index(b);
+ if (!x) return -1;
+ ival = PyInt_AsSsize_t(x);
+ Py_DECREF(x);
+ return ival;
+}
+static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject* o) {
+ if (sizeof(Py_hash_t) == sizeof(Py_ssize_t)) {
+ return (Py_hash_t) __Pyx_PyIndex_AsSsize_t(o);
+#if PY_MAJOR_VERSION < 3
+ } else if (likely(PyInt_CheckExact(o))) {
+ return PyInt_AS_LONG(o);
+#endif
+ } else {
+ Py_ssize_t ival;
+ PyObject *x;
+ x = PyNumber_Index(o);
+ if (!x) return -1;
+ ival = PyInt_AsLong(x);
+ Py_DECREF(x);
+ return ival;
+ }
+}
+static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b) {
+ return b ? __Pyx_NewRef(Py_True) : __Pyx_NewRef(Py_False);
+}
+static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t ival) {
+ return PyInt_FromSize_t(ival);
+}
+
+
+/* #### Code section: utility_code_pragmas_end ### */
+#ifdef _MSC_VER
+#pragma warning( pop )
+#endif
+
+
+
+/* #### Code section: end ### */
+#endif /* Py_PYTHON_H */
diff --git a/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py b/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.pyx b/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..5f921bb2e0ad8be2f9f35b59a452327b191fae78
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.pyx
@@ -0,0 +1,163 @@
+# --------------------------------------------------------
+# Fast R-CNN
+# Copyright (c) 2015 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+import numpy as np
+cimport numpy as np
+
+cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
+ return a if a >= b else b
+
+cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
+ return a if a <= b else b
+
+def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
+ cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+ cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+ cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+ cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
+ cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
+
+ cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1]
+
+ cdef int ndets = dets.shape[0]
+ cdef np.ndarray[np.int_t, ndim=1] suppressed = \
+ np.zeros((ndets), dtype=np.int)
+
+ # nominal indices
+ cdef int _i, _j
+ # sorted indices
+ cdef int i, j
+ # temp variables for box i's (the box currently under consideration)
+ cdef np.float32_t ix1, iy1, ix2, iy2, iarea
+ # variables for computing overlap with box j (lower scoring box)
+ cdef np.float32_t xx1, yy1, xx2, yy2
+ cdef np.float32_t w, h
+ cdef np.float32_t inter, ovr
+
+ keep = []
+ for _i in range(ndets):
+ i = order[_i]
+ if suppressed[i] == 1:
+ continue
+ keep.append(i)
+ ix1 = x1[i]
+ iy1 = y1[i]
+ ix2 = x2[i]
+ iy2 = y2[i]
+ iarea = areas[i]
+ for _j in range(_i + 1, ndets):
+ j = order[_j]
+ if suppressed[j] == 1:
+ continue
+ xx1 = max(ix1, x1[j])
+ yy1 = max(iy1, y1[j])
+ xx2 = min(ix2, x2[j])
+ yy2 = min(iy2, y2[j])
+ w = max(0.0, xx2 - xx1 + 1)
+ h = max(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (iarea + areas[j] - inter)
+ if ovr >= thresh:
+ suppressed[j] = 1
+
+ return keep
+
+def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0):
+ cdef unsigned int N = boxes.shape[0]
+ cdef float iw, ih, box_area
+ cdef float ua
+ cdef int pos = 0
+ cdef float maxscore = 0
+ cdef int maxpos = 0
+ cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov
+
+ for i in range(N):
+ maxscore = boxes[i, 4]
+ maxpos = i
+
+ tx1 = boxes[i,0]
+ ty1 = boxes[i,1]
+ tx2 = boxes[i,2]
+ ty2 = boxes[i,3]
+ ts = boxes[i,4]
+
+ pos = i + 1
+ # get max box
+ while pos < N:
+ if maxscore < boxes[pos, 4]:
+ maxscore = boxes[pos, 4]
+ maxpos = pos
+ pos = pos + 1
+
+ # add max box as a detection
+ boxes[i,0] = boxes[maxpos,0]
+ boxes[i,1] = boxes[maxpos,1]
+ boxes[i,2] = boxes[maxpos,2]
+ boxes[i,3] = boxes[maxpos,3]
+ boxes[i,4] = boxes[maxpos,4]
+
+ # swap ith box with position of max box
+ boxes[maxpos,0] = tx1
+ boxes[maxpos,1] = ty1
+ boxes[maxpos,2] = tx2
+ boxes[maxpos,3] = ty2
+ boxes[maxpos,4] = ts
+
+ tx1 = boxes[i,0]
+ ty1 = boxes[i,1]
+ tx2 = boxes[i,2]
+ ty2 = boxes[i,3]
+ ts = boxes[i,4]
+
+ pos = i + 1
+ # NMS iterations, note that N changes if detection boxes fall below threshold
+ while pos < N:
+ x1 = boxes[pos, 0]
+ y1 = boxes[pos, 1]
+ x2 = boxes[pos, 2]
+ y2 = boxes[pos, 3]
+ s = boxes[pos, 4]
+
+ area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ iw = (min(tx2, x2) - max(tx1, x1) + 1)
+ if iw > 0:
+ ih = (min(ty2, y2) - max(ty1, y1) + 1)
+ if ih > 0:
+ ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
+ ov = iw * ih / ua #iou between max box and detection box
+
+ if method == 1: # linear
+ if ov > Nt:
+ weight = 1 - ov
+ else:
+ weight = 1
+ elif method == 2: # gaussian
+ weight = np.exp(-(ov * ov)/sigma)
+ else: # original NMS
+ if ov > Nt:
+ weight = 0
+ else:
+ weight = 1
+
+ boxes[pos, 4] = weight*boxes[pos, 4]
+
+ # if box score falls below threshold, discard the box by swapping with last box
+ # update N
+ if boxes[pos, 4] < threshold:
+ boxes[pos,0] = boxes[N-1, 0]
+ boxes[pos,1] = boxes[N-1, 1]
+ boxes[pos,2] = boxes[N-1, 2]
+ boxes[pos,3] = boxes[N-1, 3]
+ boxes[pos,4] = boxes[N-1, 4]
+ N = N - 1
+ pos = pos - 1
+
+ pos = pos + 1
+
+ keep = [i for i in range(N)]
+ return keep
diff --git a/external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.hpp b/external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..68b6d42cd88b59496b22a9e77919abe529b09014
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.hpp
@@ -0,0 +1,2 @@
+void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
+ int boxes_dim, float nms_overlap_thresh, int device_id);
diff --git a/external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.pyx b/external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..59d84afe94e42de3c456b73580ed83358a2b30d8
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.pyx
@@ -0,0 +1,31 @@
+# --------------------------------------------------------
+# Faster R-CNN
+# Copyright (c) 2015 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+import numpy as np
+cimport numpy as np
+
+assert sizeof(int) == sizeof(np.int32_t)
+
+cdef extern from "gpu_nms.hpp":
+ void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int)
+
+def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh,
+ np.int32_t device_id=0):
+ cdef int boxes_num = dets.shape[0]
+ cdef int boxes_dim = dets.shape[1]
+ cdef int num_out
+ cdef np.ndarray[np.int32_t, ndim=1] \
+ keep = np.zeros(boxes_num, dtype=np.int32)
+ cdef np.ndarray[np.float32_t, ndim=1] \
+ scores = dets[:, 4]
+ cdef np.ndarray[np.int_t, ndim=1] \
+ order = scores.argsort()[::-1]
+ cdef np.ndarray[np.float32_t, ndim=2] \
+ sorted_dets = dets[order, :]
+ _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id)
+ keep = keep[:num_out]
+ return list(order[keep])
diff --git a/external/landmark_detection/FaceBoxesV2/utils/nms/nms_kernel.cu b/external/landmark_detection/FaceBoxesV2/utils/nms/nms_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..038a59012f60ebdf1182ecb778eb3b01a69bc5ed
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/nms/nms_kernel.cu
@@ -0,0 +1,144 @@
+// ------------------------------------------------------------------
+// Faster R-CNN
+// Copyright (c) 2015 Microsoft
+// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
+// Written by Shaoqing Ren
+// ------------------------------------------------------------------
+
+#include "gpu_nms.hpp"
+#include
+#include
+
+#define CUDA_CHECK(condition) \
+ /* Code block avoids redefinition of cudaError_t error */ \
+ do { \
+ cudaError_t error = condition; \
+ if (error != cudaSuccess) { \
+ std::cout << cudaGetErrorString(error) << std::endl; \
+ } \
+ } while (0)
+
+#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
+int const threadsPerBlock = sizeof(unsigned long long) * 8;
+
+__device__ inline float devIoU(float const * const a, float const * const b) {
+ float left = max(a[0], b[0]), right = min(a[2], b[2]);
+ float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
+ float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
+ float interS = width * height;
+ float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
+ float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
+ return interS / (Sa + Sb - interS);
+}
+
+__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
+ const float *dev_boxes, unsigned long long *dev_mask) {
+ const int row_start = blockIdx.y;
+ const int col_start = blockIdx.x;
+
+ // if (row_start > col_start) return;
+
+ const int row_size =
+ min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
+ const int col_size =
+ min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
+
+ __shared__ float block_boxes[threadsPerBlock * 5];
+ if (threadIdx.x < col_size) {
+ block_boxes[threadIdx.x * 5 + 0] =
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
+ block_boxes[threadIdx.x * 5 + 1] =
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
+ block_boxes[threadIdx.x * 5 + 2] =
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
+ block_boxes[threadIdx.x * 5 + 3] =
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
+ block_boxes[threadIdx.x * 5 + 4] =
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
+ }
+ __syncthreads();
+
+ if (threadIdx.x < row_size) {
+ const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
+ const float *cur_box = dev_boxes + cur_box_idx * 5;
+ int i = 0;
+ unsigned long long t = 0;
+ int start = 0;
+ if (row_start == col_start) {
+ start = threadIdx.x + 1;
+ }
+ for (i = start; i < col_size; i++) {
+ if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
+ t |= 1ULL << i;
+ }
+ }
+ const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
+ dev_mask[cur_box_idx * col_blocks + col_start] = t;
+ }
+}
+
+void _set_device(int device_id) {
+ int current_device;
+ CUDA_CHECK(cudaGetDevice(¤t_device));
+ if (current_device == device_id) {
+ return;
+ }
+ // The call to cudaSetDevice must come before any calls to Get, which
+ // may perform initialization using the GPU.
+ CUDA_CHECK(cudaSetDevice(device_id));
+}
+
+void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
+ int boxes_dim, float nms_overlap_thresh, int device_id) {
+ _set_device(device_id);
+
+ float* boxes_dev = NULL;
+ unsigned long long* mask_dev = NULL;
+
+ const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
+
+ CUDA_CHECK(cudaMalloc(&boxes_dev,
+ boxes_num * boxes_dim * sizeof(float)));
+ CUDA_CHECK(cudaMemcpy(boxes_dev,
+ boxes_host,
+ boxes_num * boxes_dim * sizeof(float),
+ cudaMemcpyHostToDevice));
+
+ CUDA_CHECK(cudaMalloc(&mask_dev,
+ boxes_num * col_blocks * sizeof(unsigned long long)));
+
+ dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
+ DIVUP(boxes_num, threadsPerBlock));
+ dim3 threads(threadsPerBlock);
+ nms_kernel<<>>(boxes_num,
+ nms_overlap_thresh,
+ boxes_dev,
+ mask_dev);
+
+ std::vector mask_host(boxes_num * col_blocks);
+ CUDA_CHECK(cudaMemcpy(&mask_host[0],
+ mask_dev,
+ sizeof(unsigned long long) * boxes_num * col_blocks,
+ cudaMemcpyDeviceToHost));
+
+ std::vector remv(col_blocks);
+ memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
+
+ int num_to_keep = 0;
+ for (int i = 0; i < boxes_num; i++) {
+ int nblock = i / threadsPerBlock;
+ int inblock = i % threadsPerBlock;
+
+ if (!(remv[nblock] & (1ULL << inblock))) {
+ keep_out[num_to_keep++] = i;
+ unsigned long long *p = &mask_host[0] + i * col_blocks;
+ for (int j = nblock; j < col_blocks; j++) {
+ remv[j] |= p[j];
+ }
+ }
+ }
+ *num_out = num_to_keep;
+
+ CUDA_CHECK(cudaFree(boxes_dev));
+ CUDA_CHECK(cudaFree(mask_dev));
+}
diff --git a/external/landmark_detection/FaceBoxesV2/utils/nms/py_cpu_nms.py b/external/landmark_detection/FaceBoxesV2/utils/nms/py_cpu_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..54e7b25fef72b518df6dcf8d6fb78b986796c6e3
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/nms/py_cpu_nms.py
@@ -0,0 +1,38 @@
+# --------------------------------------------------------
+# Fast R-CNN
+# Copyright (c) 2015 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+import numpy as np
+
+def py_cpu_nms(dets, thresh):
+ """Pure Python NMS baseline."""
+ x1 = dets[:, 0]
+ y1 = dets[:, 1]
+ x2 = dets[:, 2]
+ y2 = dets[:, 3]
+ scores = dets[:, 4]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= thresh)[0]
+ order = order[inds + 1]
+
+ return keep
diff --git a/external/landmark_detection/FaceBoxesV2/utils/nms_wrapper.py b/external/landmark_detection/FaceBoxesV2/utils/nms_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..d529875fac67e070ea865a1ba0cc3d248847827f
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/nms_wrapper.py
@@ -0,0 +1,15 @@
+# --------------------------------------------------------
+# Fast R-CNN
+# Copyright (c) 2015 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+from .nms.cpu_nms import cpu_nms, cpu_soft_nms
+
+def nms(dets, thresh):
+ """Dispatch to either CPU or GPU NMS implementations."""
+
+ if dets.shape[0] == 0:
+ return []
+ return cpu_nms(dets, thresh)
diff --git a/external/landmark_detection/FaceBoxesV2/utils/prior_box.py b/external/landmark_detection/FaceBoxesV2/utils/prior_box.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5536670afe139de420bc16bd88238fd2a90735b
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/prior_box.py
@@ -0,0 +1,43 @@
+import torch
+from itertools import product as product
+import numpy as np
+from math import ceil
+
+
+class PriorBox(object):
+ def __init__(self, cfg, image_size=None, phase='train'):
+ super(PriorBox, self).__init__()
+ #self.aspect_ratios = cfg['aspect_ratios']
+ self.min_sizes = cfg['min_sizes']
+ self.steps = cfg['steps']
+ self.clip = cfg['clip']
+ self.image_size = image_size
+ self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
+
+ def forward(self):
+ anchors = []
+ for k, f in enumerate(self.feature_maps):
+ min_sizes = self.min_sizes[k]
+ for i, j in product(range(f[0]), range(f[1])):
+ for min_size in min_sizes:
+ s_kx = min_size / self.image_size[1]
+ s_ky = min_size / self.image_size[0]
+ if min_size == 32:
+ dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.25, j+0.5, j+0.75]]
+ dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.25, i+0.5, i+0.75]]
+ for cy, cx in product(dense_cy, dense_cx):
+ anchors += [cx, cy, s_kx, s_ky]
+ elif min_size == 64:
+ dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.5]]
+ dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.5]]
+ for cy, cx in product(dense_cy, dense_cx):
+ anchors += [cx, cy, s_kx, s_ky]
+ else:
+ cx = (j + 0.5) * self.steps[k] / self.image_size[1]
+ cy = (i + 0.5) * self.steps[k] / self.image_size[0]
+ anchors += [cx, cy, s_kx, s_ky]
+ # back to torch land
+ output = torch.Tensor(anchors).view(-1, 4)
+ if self.clip:
+ output.clamp_(max=1, min=0)
+ return output
diff --git a/external/landmark_detection/FaceBoxesV2/utils/timer.py b/external/landmark_detection/FaceBoxesV2/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4b3b8098a5ad41f8d18d42b6b2fedb694aa5508
--- /dev/null
+++ b/external/landmark_detection/FaceBoxesV2/utils/timer.py
@@ -0,0 +1,40 @@
+# --------------------------------------------------------
+# Fast R-CNN
+# Copyright (c) 2015 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+import time
+
+
+class Timer(object):
+ """A simple timer."""
+ def __init__(self):
+ self.total_time = 0.
+ self.calls = 0
+ self.start_time = 0.
+ self.diff = 0.
+ self.average_time = 0.
+
+ def tic(self):
+ # using time.time instead of time.clock because time time.clock
+ # does not normalize for multithreading
+ self.start_time = time.time()
+
+ def toc(self, average=True):
+ self.diff = time.time() - self.start_time
+ self.total_time += self.diff
+ self.calls += 1
+ self.average_time = self.total_time / self.calls
+ if average:
+ return self.average_time
+ else:
+ return self.diff
+
+ def clear(self):
+ self.total_time = 0.
+ self.calls = 0
+ self.start_time = 0.
+ self.diff = 0.
+ self.average_time = 0.
diff --git a/external/landmark_detection/README.md b/external/landmark_detection/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a41395b1fc98636bd762e551a15f577cdce048e5
--- /dev/null
+++ b/external/landmark_detection/README.md
@@ -0,0 +1,110 @@
+# STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection.
+
+Paper Link: [arxiv](https://arxiv.org/abs/2306.02763) | [CVPR 2023](https://openaccess.thecvf.com/content/CVPR2023/papers/Zhou_STAR_Loss_Reducing_Semantic_Ambiguity_in_Facial_Landmark_Detection_CVPR_2023_paper.pdf)
+
+
+- Pytorch implementation of **S**elf-adap**T**ive **A**mbiguity **R**eduction (**STAR**) loss.
+- STAR loss is a self-adaptive anisotropic direction loss, which can be used in heatmap regression-based methods for facial landmark detection.
+- Specifically, we find that semantic ambiguity results in the anisotropic predicted distribution, which inspires us to use predicted distribution to represent semantic ambiguity. So, we use PCA to indicate the character of the predicted distribution and indirectly formulate the direction and intensity of semantic ambiguity. Based on this, STAR loss adaptively suppresses the prediction error in the ambiguity direction to mitigate the impact of ambiguity annotation in training. More details can be found in our paper.
+
+
+
+
+
+
+
+## Dependencies
+
+* python==3.7.3
+* PyTorch=1.6.0
+* requirements.txt
+
+## Dataset Preparation
+
+ - Step1: Download the raw images from [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/#dataset), [300W](https://ibug.doc.ic.ac.uk/resources/300-W/), and [WFLW](https://wywu.github.io/projects/LAB/WFLW.html).
+ - Step2: We follow the data preprocess in [ADNet](https://openaccess.thecvf.com/content/ICCV2021/papers/Huang_ADNet_Leveraging_Error-Bias_Towards_Normal_Direction_in_Face_Alignment_ICCV_2021_paper.pdf), and the metadata can be download from [the corresponding repository](https://github.com/huangyangyu/ADNet).
+ - Step3: Make them look like this:
+```script
+# the dataset directory:
+|-- ${image_dir}
+ |-- WFLW
+ | -- WFLW_images
+ |-- 300W
+ | -- afw
+ | -- helen
+ | -- ibug
+ | -- lfpw
+ |-- COFW
+ | -- train
+ | -- test
+|-- ${annot_dir}
+ |-- WFLW
+ |-- train.tsv, test.tsv
+ |-- 300W
+ |-- train.tsv, test.tsv
+ |--COFW
+ |-- train.tsv, test.tsv
+```
+
+## Usage
+* Work directory: set the ${ckpt_dir} in ./conf/alignment.py.
+* Pretrained model:
+
+| Dataset | Model |
+|:-----------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| WFLW | [google](https://drive.google.com/file/d/1aOx0wYEZUfBndYy_8IYszLPG_D2fhxrT/view?usp=sharing) / [baidu](https://pan.baidu.com/s/10vvI-ovs3x9NrdmpnXK6sg?pwd=u0yu) |
+| 300W | [google](https://drive.google.com/file/d/1Fiu3hjjkQRdKsWE9IgyNPdiJSz9_MzA5/view?usp=sharing) / [baidu](https://pan.baidu.com/s/1bjUhLq1zS1XSl1nX78fU7A?pwd=yb2s) |
+| COFW | [google](https://drive.google.com/file/d/1NFcZ9jzql_jnn3ulaSzUlyhS05HWB9n_/view?usp=drive_link) / [baidu](https://pan.baidu.com/s/1XO6hDZ8siJLTgFcpyu1Tzw?pwd=m57n) |
+
+
+### Training
+```shell
+python main.py --mode=train --device_ids=0,1,2,3 \
+ --image_dir=${image_dir} --annot_dir=${annot_dir} \
+ --data_definition={WFLW, 300W, COFW}
+```
+
+### Testing
+```shell
+python main.py --mode=test --device_ids=0 \
+ --image_dir=${image_dir} --annot_dir=${annot_dir} \
+ --data_definition={WFLW, 300W, COFW} \
+ --pretrained_weight=${model_path} \
+```
+
+### Evaluation
+```shell
+python evaluate.py --device_ids=0 \
+ --model_path=${model_path} --metadata_path=${metadata_path} \
+ --image_dir=${image_dir} --data_definition={WFLW, 300W, COFW} \
+```
+
+To test on your own image, the following code could be considered:
+```shell
+python demo.py
+```
+
+
+## Results
+The models trained by STAR Loss achieved **SOTA** performance in all of COFW, 300W and WFLW datasets.
+
+
+
+
+
+## BibTeX Citation
+Please consider citing our papers in your publications if the project helps your research. BibTeX reference is as follows.
+```
+@inproceedings{Zhou_2023_CVPR,
+ author = {Zhou, Zhenglin and Li, Huaxia and Liu, Hong and Wang, Nanyang and Yu, Gang and Ji, Rongrong},
+ title = {STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection},
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2023},
+ pages = {15475-15484}
+}
+```
+
+## Acknowledgments
+This repository is built on top of [ADNet](https://github.com/huangyangyu/ADNet).
+Thanks for this strong baseline.
diff --git a/external/landmark_detection/conf/__init__.py b/external/landmark_detection/conf/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f92d0e82f402d1599c16deb5d1f0c3bb568bfb3
--- /dev/null
+++ b/external/landmark_detection/conf/__init__.py
@@ -0,0 +1 @@
+from .alignment import Alignment
\ No newline at end of file
diff --git a/external/landmark_detection/conf/alignment.py b/external/landmark_detection/conf/alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..eebaa1d7bd02218f6f9616e4635bd530bddd619d
--- /dev/null
+++ b/external/landmark_detection/conf/alignment.py
@@ -0,0 +1,239 @@
+import os.path as osp
+from .base import Base
+
+
+class Alignment(Base):
+ """
+ Alignment configure file, which contains training parameters of alignment.
+ """
+
+ def __init__(self, args):
+ super(Alignment, self).__init__('alignment')
+ self.ckpt_dir = '/mnt/workspace/humanAIGC/project/STAR/weights'
+ self.net = "stackedHGnet_v1"
+ self.nstack = 4
+ self.loader_type = "alignment"
+ self.data_definition = "300W" # COFW, 300W, WFLW
+ self.test_file = "test.tsv"
+
+ # image
+ self.channels = 3
+ self.width = 256
+ self.height = 256
+ self.means = (127.5, 127.5, 127.5)
+ self.scale = 1 / 127.5
+ self.aug_prob = 1.0
+
+ self.display_iteration = 10
+ self.val_epoch = 1
+ self.valset = "test.tsv"
+ self.norm_type = 'default'
+ self.encoder_type = 'default'
+ self.decoder_type = 'default'
+
+ # scheduler & optimizer
+ self.milestones = [200, 350, 450]
+ self.max_epoch = 260
+ self.optimizer = "adam"
+ self.learn_rate = 0.001
+ self.weight_decay = 0.00001
+ self.betas = [0.9, 0.999]
+ self.gamma = 0.1
+
+ # batch_size & workers
+ self.batch_size = 32
+ self.train_num_workers = 16
+ self.val_batch_size = 32
+ self.val_num_workers = 16
+ self.test_batch_size = 16
+ self.test_num_workers = 0
+
+ # tricks
+ self.ema = True
+ self.add_coord = True
+ self.use_AAM = True
+
+ # loss
+ self.loss_func = "STARLoss_v2"
+
+ # STAR Loss paras
+ self.star_w = 1
+ self.star_dist = 'smoothl1'
+
+ self.init_from_args(args)
+
+ # COFW
+ if self.data_definition == "COFW":
+ self.edge_info = (
+ (True, (0, 4, 2, 5)), # RightEyebrow
+ (True, (1, 6, 3, 7)), # LeftEyebrow
+ (True, (8, 12, 10, 13)), # RightEye
+ (False, (9, 14, 11, 15)), # LeftEye
+ (True, (18, 20, 19, 21)), # Nose
+ (True, (22, 26, 23, 27)), # LowerLip
+ (True, (22, 24, 23, 25)), # UpperLip
+ )
+ if self.norm_type == 'ocular':
+ self.nme_left_index = 8 # ocular
+ self.nme_right_index = 9 # ocular
+ elif self.norm_type in ['pupil', 'default']:
+ self.nme_left_index = 16 # pupil
+ self.nme_right_index = 17 # pupil
+ else:
+ raise NotImplementedError
+ self.classes_num = [29, 7, 29]
+ self.crop_op = True
+ self.flip_mapping = (
+ [0, 1], [4, 6], [2, 3], [5, 7], [8, 9], [10, 11], [12, 14], [16, 17], [13, 15], [18, 19], [22, 23],
+ )
+ self.image_dir = osp.join(self.image_dir, 'COFW')
+ # 300W
+ elif self.data_definition == "300W":
+ self.edge_info = (
+ (False, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), # FaceContour
+ (False, (17, 18, 19, 20, 21)), # RightEyebrow
+ (False, (22, 23, 24, 25, 26)), # LeftEyebrow
+ (False, (27, 28, 29, 30)), # NoseLine
+ (False, (31, 32, 33, 34, 35)), # Nose
+ (True, (36, 37, 38, 39, 40, 41)), # RightEye
+ (True, (42, 43, 44, 45, 46, 47)), # LeftEye
+ (True, (48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59)), # OuterLip
+ (True, (60, 61, 62, 63, 64, 65, 66, 67)), # InnerLip
+ )
+ if self.norm_type in ['ocular', 'default']:
+ self.nme_left_index = 36 # ocular
+ self.nme_right_index = 45 # ocular
+ elif self.norm_type == 'pupil':
+ self.nme_left_index = [36, 37, 38, 39, 40, 41] # pupil
+ self.nme_right_index = [42, 43, 44, 45, 46, 47] # pupil
+ else:
+ raise NotImplementedError
+ self.classes_num = [68, 9, 68]
+ self.crop_op = True
+ self.flip_mapping = (
+ [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],
+ [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
+ [31, 35], [32, 34],
+ [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
+ [48, 54], [49, 53], [50, 52], [61, 63], [60, 64], [67, 65], [58, 56], [59, 55],
+ )
+ self.image_dir = osp.join(self.image_dir, '300W')
+ # self.image_dir = osp.join(self.image_dir, '300VW_images')
+ # 300VW
+ elif self.data_definition == "300VW":
+ self.edge_info = (
+ (False, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), # FaceContour
+ (False, (17, 18, 19, 20, 21)), # RightEyebrow
+ (False, (22, 23, 24, 25, 26)), # LeftEyebrow
+ (False, (27, 28, 29, 30)), # NoseLine
+ (False, (31, 32, 33, 34, 35)), # Nose
+ (True, (36, 37, 38, 39, 40, 41)), # RightEye
+ (True, (42, 43, 44, 45, 46, 47)), # LeftEye
+ (True, (48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59)), # OuterLip
+ (True, (60, 61, 62, 63, 64, 65, 66, 67)), # InnerLip
+ )
+ if self.norm_type in ['ocular', 'default']:
+ self.nme_left_index = 36 # ocular
+ self.nme_right_index = 45 # ocular
+ elif self.norm_type == 'pupil':
+ self.nme_left_index = [36, 37, 38, 39, 40, 41] # pupil
+ self.nme_right_index = [42, 43, 44, 45, 46, 47] # pupil
+ else:
+ raise NotImplementedError
+ self.classes_num = [68, 9, 68]
+ self.crop_op = True
+ self.flip_mapping = (
+ [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],
+ [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
+ [31, 35], [32, 34],
+ [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
+ [48, 54], [49, 53], [50, 52], [61, 63], [60, 64], [67, 65], [58, 56], [59, 55],
+ )
+ self.image_dir = osp.join(self.image_dir, '300VW_Dataset_2015_12_14')
+ # WFLW
+ elif self.data_definition == "WFLW":
+ self.edge_info = (
+ (False, (
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
+ 27,
+ 28, 29, 30, 31, 32)), # FaceContour
+ (True, (33, 34, 35, 36, 37, 38, 39, 40, 41)), # RightEyebrow
+ (True, (42, 43, 44, 45, 46, 47, 48, 49, 50)), # LeftEyebrow
+ (False, (51, 52, 53, 54)), # NoseLine
+ (False, (55, 56, 57, 58, 59)), # Nose
+ (True, (60, 61, 62, 63, 64, 65, 66, 67)), # RightEye
+ (True, (68, 69, 70, 71, 72, 73, 74, 75)), # LeftEye
+ (True, (76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87)), # OuterLip
+ (True, (88, 89, 90, 91, 92, 93, 94, 95)), # InnerLip
+ )
+ if self.norm_type in ['ocular', 'default']:
+ self.nme_left_index = 60 # ocular
+ self.nme_right_index = 72 # ocular
+ elif self.norm_type == 'pupil':
+ self.nme_left_index = 96 # pupils
+ self.nme_right_index = 97 # pupils
+ else:
+ raise NotImplementedError
+ self.classes_num = [98, 9, 98]
+ self.crop_op = True
+ self.flip_mapping = (
+ [0, 32], [1, 31], [2, 30], [3, 29], [4, 28], [5, 27], [6, 26], [7, 25], [8, 24], [9, 23], [10, 22],
+ [11, 21], [12, 20], [13, 19], [14, 18], [15, 17], # cheek
+ [33, 46], [34, 45], [35, 44], [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47], # elbrow
+ [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], [66, 74], [67, 73],
+ [55, 59], [56, 58],
+ [76, 82], [77, 81], [78, 80], [87, 83], [86, 84],
+ [88, 92], [89, 91], [95, 93], [96, 97]
+ )
+ self.image_dir = osp.join(self.image_dir, 'WFLW', 'WFLW_images')
+
+ self.label_num = self.nstack * 3 if self.use_AAM else self.nstack
+ self.loss_weights, self.criterions, self.metrics = [], [], []
+ for i in range(self.nstack):
+ factor = (2 ** i) / (2 ** (self.nstack - 1))
+ if self.use_AAM:
+ self.loss_weights += [factor * weight for weight in [1.0, 10.0, 10.0]]
+ self.criterions += [self.loss_func, "AWingLoss", "AWingLoss"]
+ self.metrics += ["NME", None, None]
+ else:
+ self.loss_weights += [factor * weight for weight in [1.0]]
+ self.criterions += [self.loss_func, ]
+ self.metrics += ["NME", ]
+
+ self.key_metric_index = (self.nstack - 1) * 3 if self.use_AAM else (self.nstack - 1)
+
+ # data
+ self.folder = self.get_foldername()
+ self.work_dir = osp.join(self.ckpt_dir, self.data_definition, self.folder)
+ self.model_dir = osp.join(self.work_dir, 'model')
+ self.log_dir = osp.join(self.work_dir, 'log')
+
+ self.train_tsv_file = osp.join(self.annot_dir, self.data_definition, "train.tsv")
+ self.train_pic_dir = self.image_dir
+
+ self.val_tsv_file = osp.join(self.annot_dir, self.data_definition, self.valset)
+ self.val_pic_dir = self.image_dir
+
+ self.test_tsv_file = osp.join(self.annot_dir, self.data_definition, self.test_file)
+ self.test_pic_dir = self.image_dir
+
+ # self.train_tsv_file = osp.join(self.annot_dir, '300VW', "train.tsv")
+ # self.train_pic_dir = self.image_dir
+
+ # self.val_tsv_file = osp.join(self.annot_dir, '300VW', self.valset)
+ # self.val_pic_dir = self.image_dir
+
+ # self.test_tsv_file = osp.join(self.annot_dir, '300VW', self.test_file)
+ # self.test_pic_dir = self.image_dir
+
+
+ def get_foldername(self):
+ str = ''
+ str += '{}_{}x{}_{}_ep{}_lr{}_bs{}'.format(self.data_definition, self.height, self.width,
+ self.optimizer, self.max_epoch, self.learn_rate, self.batch_size)
+ str += '_{}'.format(self.loss_func)
+ str += '_{}_{}'.format(self.star_dist, self.star_w) if self.loss_func == 'STARLoss' else ''
+ str += '_AAM' if self.use_AAM else ''
+ str += '_{}'.format(self.valset[:-4]) if self.valset != 'test.tsv' else ''
+ str += '_{}'.format(self.id)
+ return str
diff --git a/external/landmark_detection/conf/base.py b/external/landmark_detection/conf/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..55aded090aa40fa934354d0ec1fc7df5823edef2
--- /dev/null
+++ b/external/landmark_detection/conf/base.py
@@ -0,0 +1,94 @@
+import uuid
+import logging
+import os.path as osp
+from argparse import Namespace
+# from tensorboardX import SummaryWriter
+
+class Base:
+ """
+ Base configure file, which contains the basic training parameters and should be inherited by other attribute configure file.
+ """
+
+ def __init__(self, config_name, ckpt_dir='./', image_dir='./', annot_dir='./'):
+ self.type = config_name
+ self.id = str(uuid.uuid4())
+ self.note = ""
+
+ self.ckpt_dir = ckpt_dir
+ self.image_dir = image_dir
+ self.annot_dir = annot_dir
+
+ self.loader_type = "alignment"
+ self.loss_func = "STARLoss"
+
+ # train
+ self.batch_size = 128
+ self.val_batch_size = 1
+ self.test_batch_size = 32
+ self.channels = 3
+ self.width = 256
+ self.height = 256
+
+ # mean values in r, g, b channel.
+ self.means = (127, 127, 127)
+ self.scale = 0.0078125
+
+ self.display_iteration = 100
+ self.milestones = [50, 80]
+ self.max_epoch = 100
+
+ self.net = "stackedHGnet_v1"
+ self.nstack = 4
+
+ # ["adam", "sgd"]
+ self.optimizer = "adam"
+ self.learn_rate = 0.1
+ self.momentum = 0.01 # caffe: 0.99
+ self.weight_decay = 0.0
+ self.nesterov = False
+ self.scheduler = "MultiStepLR"
+ self.gamma = 0.1
+
+ self.loss_weights = [1.0]
+ self.criterions = ["SoftmaxWithLoss"]
+ self.metrics = ["Accuracy"]
+ self.key_metric_index = 0
+ self.classes_num = [1000]
+ self.label_num = len(self.classes_num)
+
+ # model
+ self.ema = False
+ self.use_AAM = True
+
+ # visualization
+ self.writer = None
+
+ # log file
+ self.logger = None
+
+ def init_instance(self):
+ # self.writer = SummaryWriter(logdir=self.log_dir, comment=self.type)
+ log_formatter = logging.Formatter("%(asctime)s %(levelname)-8s: %(message)s")
+ root_logger = logging.getLogger()
+ file_handler = logging.FileHandler(osp.join(self.log_dir, "log.txt"))
+ file_handler.setFormatter(log_formatter)
+ file_handler.setLevel(logging.NOTSET)
+ root_logger.addHandler(file_handler)
+ console_handler = logging.StreamHandler()
+ console_handler.setFormatter(log_formatter)
+ console_handler.setLevel(logging.NOTSET)
+ root_logger.addHandler(console_handler)
+ root_logger.setLevel(logging.NOTSET)
+ self.logger = root_logger
+
+ def __del__(self):
+ # tensorboard --logdir self.log_dir
+ if self.writer is not None:
+ # self.writer.export_scalars_to_json(self.log_dir + "visual.json")
+ self.writer.close()
+
+ def init_from_args(self, args: Namespace):
+ args_vars = vars(args)
+ for key, value in args_vars.items():
+ if hasattr(self, key) and value is not None:
+ setattr(self, key, value)
diff --git a/external/landmark_detection/config.json b/external/landmark_detection/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..35831f0d94065ba9b748dfeab3a3bf1aa25f1de3
--- /dev/null
+++ b/external/landmark_detection/config.json
@@ -0,0 +1,15 @@
+{
+ "Token":"bpt4JPotFA6bpdknR9ZDCw",
+ "business_flag": "shadow_cv_face",
+ "model_local_file_path": "/apdcephfs_cq3/share_1134483/charlinzhou/Documents/awesome-tools/jizhi/",
+ "host_num": 1,
+ "host_gpu_num": 1,
+ "GPUName": "V100",
+ "is_elasticity": true,
+ "enable_evicted_pulled_up": true,
+ "task_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
+ "task_flag": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
+ "model_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
+ "image_full_name": "mirrors.tencent.com/haroldzcli/py36-pytorch1.7.1-torchvision0.8.2-cuda10.1-cudnn7.6",
+ "start_cmd": "./start_slpt.sh /apdcephfs_cq3/share_1134483/charlinzhou/Documents/SLPT_Training train.py --loss_func=star --bb_init --eigen_box --dist_func=align_smoothl1"
+}
diff --git a/external/landmark_detection/data_processor/CheckFaceKeyPoint.py b/external/landmark_detection/data_processor/CheckFaceKeyPoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..edacbae3f8d8a40d5abd012e63a07f1e666e0dbe
--- /dev/null
+++ b/external/landmark_detection/data_processor/CheckFaceKeyPoint.py
@@ -0,0 +1,147 @@
+import os
+
+import cv2
+import numpy as np
+from PIL import Image
+
+selected_indices_old = [
+ 2311,
+ 2416,
+ 2437,
+ 2460,
+ 2495,
+ 2518,
+ 2520,
+ 2627,
+ 4285,
+ 4315,
+ 6223,
+ 6457,
+ 6597,
+ 6642,
+ 6974,
+ 7054,
+ 7064,
+ 7182,
+ 7303,
+ 7334,
+ 7351,
+ 7368,
+ 7374,
+ 7493,
+ 7503,
+ 7626,
+ 8443,
+ 8562,
+ 8597,
+ 8701,
+ 8817,
+ 8953,
+ 11213,
+ 11261,
+ 11317,
+ 11384,
+ 11600,
+ 11755,
+ 11852,
+ 11891,
+ 11945,
+ 12010,
+ 12354,
+ 12534,
+ 12736,
+ 12880,
+ 12892,
+ 13004,
+ 13323,
+ 13371,
+ 13534,
+ 13575,
+ 14874,
+ 14949,
+ 14977,
+ 15052,
+ 15076,
+ 15291,
+ 15620,
+ 15758,
+ 16309,
+ 16325,
+ 16348,
+ 16390,
+ 16489,
+ 16665,
+ 16891,
+ 17147,
+ 17183,
+ 17488,
+ 17549,
+ 17657,
+ 17932,
+ 19661,
+ 20162,
+ 20200,
+ 20238,
+ 20286,
+ 20432,
+ 20834,
+ 20954,
+ 21015,
+ 21036,
+ 21117,
+ 21299,
+ 21611,
+ 21632,
+ 21649,
+ 22722,
+ 22759,
+ 22873,
+ 23028,
+ 23033,
+ 23082,
+ 23187,
+ 23232,
+ 23302,
+ 23413,
+ 23430,
+ 23446,
+ 23457,
+ 23548,
+ 23636,
+ 32060,
+ 32245,
+]
+
+selected_indices = list()
+with open('/home/gyalex/Desktop/face_anno.txt', 'r') as f:
+ lines = f.readlines()
+ for line in lines:
+ hh = line.strip().split()
+ if len(hh) > 0:
+ pid = hh[0].find('.')
+ if pid != -1:
+ s = hh[0][pid+1:len(hh[0])]
+ print(s)
+ selected_indices.append(int(s))
+
+f.close()
+
+dir = '/media/gyalex/Data/face_ldk_dataset/MHC_LightingPreset_Portrait_RT_0_19/MHC_LightingPreset_Portrait_RT_seq_000015'
+
+for idx in range(500):
+ img = os.path.join(dir, "view_1/MHC_LightingPreset_Portrait_RT_seq_000015_FinalImage_" + str(idx).zfill(4) + ".jpeg")
+ lmd = os.path.join(dir, "mesh/mesh_screen" + str(idx+5).zfill(7) + ".npy")
+
+ img = cv2.imread(img)
+ # c = 511 / 2
+ # lmd = np.load(lmd) * c + c
+ # lmd[:, 1] = 511 - lmd[:, 1]
+ lmd = np.load(lmd)[selected_indices]
+ for i in range(lmd.shape[0]):
+ p = lmd[i]
+ x, y = round(float(p[0])), round(float(p[1]))
+ print(p)
+ cv2.circle(img, (x, y), 2, (0, 0, 255), -1)
+
+ cv2.imshow('win', img)
+ cv2.waitKey(0)
\ No newline at end of file
diff --git a/external/landmark_detection/data_processor/align.py b/external/landmark_detection/data_processor/align.py
new file mode 100644
index 0000000000000000000000000000000000000000..be9920e896c5034b87851118e06fe3195debbf8a
--- /dev/null
+++ b/external/landmark_detection/data_processor/align.py
@@ -0,0 +1,193 @@
+import numpy as np
+import open3d as o3d
+from scipy.spatial.transform import Rotation
+from scipy.linalg import orthogonal_procrustes
+
+from open3d.pipelines.registration import registration_ransac_based_on_correspondence
+
+
+def rigid_transform_3D(A, B):
+ assert A.shape == B.shape, "Input arrays must have the same shape"
+ assert A.shape[1] == 3, "Input arrays must be Nx3"
+
+ N = A.shape[0] # Number of points
+
+ # Compute centroids of A and B
+ centroid_A = np.mean(A, axis=0)
+ centroid_B = np.mean(B, axis=0)
+
+ # Center the points around the centroids
+ AA = A - centroid_A
+ BB = B - centroid_B
+
+ # H = AA^T * BB
+ H = np.dot(AA.T, BB)
+
+ # Singular Value Decomposition
+ U, S, Vt = np.linalg.svd(H)
+
+ # Compute rotation
+ R = np.dot(Vt.T, U.T)
+
+ # Ensure a proper rotation (det(R) should be +1)
+ if np.linalg.det(R) < 0:
+ Vt[2, :] *= -1
+ R = np.dot(Vt.T, U.T)
+
+ # Compute translation
+ t = centroid_B - np.dot(R, centroid_A)
+
+ # Construct the transform matrix (4x4)
+ transform_matrix = np.eye(4)
+ transform_matrix[:3, :3] = R
+ transform_matrix[:3, 3] = t
+
+ return transform_matrix
+
+
+def compute_rigid_transform(points1, points2):
+ """
+ 计算从points1到points2的刚体变换(包括尺度、旋转和平移)。
+
+ 参数:
+ points1, points2: np.ndarray, 形状为(68, 3)的数组,分别为两组3D对应点。
+
+ 返回:
+ scale: float, 尺度因子
+ R: np.ndarray, 3x3的旋转矩阵
+ t: np.ndarray, 3维的平移向量
+ """
+ # 中心化
+ mean1 = np.mean(points1, axis=0)
+ centered_points1 = points1 - mean1
+ mean2 = np.mean(points2, axis=0)
+ centered_points2 = points2 - mean2
+
+ # 使用orthogonal_procrustes计算旋转和平移
+ R, _ = orthogonal_procrustes(centered_points1, centered_points2)
+ t = mean2 - R @ mean1 # 计算平移向量
+
+ # 计算尺度因子
+ scale = np.mean(np.linalg.norm(centered_points2, axis=1) /
+ np.linalg.norm(centered_points1, axis=1))
+
+ return scale, R, t
+
+
+def compute_rigid_transform_new(points_A, points_B):
+ # 中心化
+ center_A = np.mean(points_A, axis=0)
+ center_B = np.mean(points_B, axis=0)
+ points_A_centered = points_A - center_A
+ points_B_centered = points_B - center_B
+
+ # 计算协方差矩阵
+ cov_matrix = np.dot(points_A_centered.T, points_B_centered)
+
+ # SVD分解
+ U, S, Vt = np.linalg.svd(cov_matrix)
+
+ # 确保旋转矩阵为正交且右手系,这里我们取Vt的转置作为旋转矩阵
+ rotation_matrix = np.dot(Vt.T, U.T)
+
+ # 检查行列式是否为-1(表示反射,不满足旋转矩阵要求),如果是,则调整一个列的符号
+ if np.linalg.det(rotation_matrix) < 0:
+ Vt[2,:] *= -1
+ rotation_matrix = np.dot(Vt.T, U.T)
+
+ # 计算尺度因子
+ scale = np.trace(np.dot(points_A_centered.T, points_B_centered)) / np.trace(np.dot(points_A_centered.T, points_A_centered))
+
+ # 计算平移向量
+ translation_vector = center_B - scale * np.dot(rotation_matrix, center_A)
+
+ return scale, rotation_matrix, translation_vector
+
+
+
+
+# 示范用法
+obj_A = '/home/gyalex/Desktop/our_face.obj'
+obj_B = '/home/gyalex/Desktop/Neutral.obj'
+
+mesh_A = o3d.io.read_triangle_mesh(obj_A)
+mesh_B = o3d.io.read_triangle_mesh(obj_B)
+
+vertices_A = np.asarray(mesh_A.vertices)
+vertices_B = np.asarray(mesh_B.vertices)
+
+list_A = list()
+list_B = list()
+with open('/home/gyalex/Desktop/our_marker.txt', 'r') as f:
+ lines_A = f.readlines()
+ for line in lines_A:
+ hh = line.strip().split()
+ list_A.append(int(hh[0]))
+
+with open('/home/gyalex/Desktop/ARKit_landmarks.txt', 'r') as f:
+ lines_B = f.readlines()
+ for line in lines_B:
+ hh = line.strip().split()
+ list_B.append(int(hh[0]))
+
+A = vertices_A[list_A,:] # 第一组3D点
+B = vertices_B[list_B,:] # 第二组3D点
+
+# scale, R, t = compute_rigid_transform(A, B)
+
+# # 定义尺度变换矩阵
+# scale_matrix = np.eye(4)
+# scale_matrix[0, 0] = scale # x轴方向放大2倍
+# scale_matrix[1, 1] = scale # y轴方向放大2倍
+# scale_matrix[2, 2] = scale # z轴方向放大2倍
+
+# transform_matrix = np.eye(4)
+# transform_matrix[:3, :3] = scale
+# transform_matrix[:3, 3] = R*t
+
+# mesh_A.transform(transform_matrix)
+# # mesh_A.transform(scale_matrix)
+
+# o3d.io.write_triangle_mesh('/home/gyalex/Desktop/our_face_new.obj', mesh_A)
+
+pcd_source = o3d.utility.Vector3dVector(A) # 示例源点云数据
+pcd_target = o3d.utility.Vector3dVector(B) # 示例目标点云数据 + 1偏移,仅作示例
+
+corres_source = list()
+for idx in range(68): corres_source.append(idx)
+corres_target = list()
+for idx in range(68): corres_target.append(idx)
+
+# 根据对应点索引获取实际的对应点坐标
+corres_source_points = pcd_source
+corres_target_points = pcd_target
+
+corres = o3d.utility.Vector2iVector([[src, tgt] for src, tgt in zip(corres_source, corres_target)])
+
+# 应用RANSAC进行基于对应点的配准
+reg_result = registration_ransac_based_on_correspondence(
+ pcd_source,
+ pcd_target,
+ corres,
+ estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(),
+ ransac_n=3,
+ criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(max_iteration=100000, epsilon=1e-6)
+)
+
+# # 使用RANSAC进行配准
+# convergence_criteria = o3d.pipelines.registration.RANSACConvergenceCriteria(max_iteration=50000, max_validation=500)
+# ransac_result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
+# pcd_source,
+# pcd_target,
+# corres,
+# o3d.pipelines.registration.TransformationEstimationPointToPoint(),
+# 3, # RANSAC阈值,根据实际情况调整
+# convergence_criteria,
+# [o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
+# o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(0.05)],
+# o3d.pipelines.registration.RANSACLoss())
+
+# 应用变换到源mesh
+# mesh_source_aligned = mesh_source.transform(reg_result.transformation)
+
+a = 0
\ No newline at end of file
diff --git a/external/landmark_detection/data_processor/process_pcd.py b/external/landmark_detection/data_processor/process_pcd.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6183ab658933d6b0d5a0994c713e561c7d2a5be
--- /dev/null
+++ b/external/landmark_detection/data_processor/process_pcd.py
@@ -0,0 +1,250 @@
+import os
+import cv2
+import numpy as np
+import open3d as o3d
+# import pyrender
+# from pyrender import mesh, DirectionalLight, Material, PerspectiveCamera
+
+os.environ['__GL_THREADED_OPTIMIZATIONS'] = '1'
+
+cord_list = []
+with open('./cord.txt', 'r') as f:
+ lines = f.readlines()
+ for line in lines:
+ m = line.split()
+ x = int(m[0])
+ y = int(m[1])
+
+ x = 1000 - x
+ y = 1000 - y
+
+ cord_list.append([x, y])
+
+
+# 假设TXT文件的路径
+output_folder = '/media/gyalex/Data/face_det_dataset/rgbd_data/rgbd'
+if not os.path.exists(output_folder):
+ os.mkdir(output_folder)
+
+for idx in range(32, 33):
+ txt_file_path = '/media/gyalex/Data/face_det_dataset/rgbd_data/PointImage'+ str(idx) + '.txt'
+ _, name = os.path.split(txt_file_path)
+ print(txt_file_path)
+
+ with open(txt_file_path, 'r') as file:
+ points = []
+ rgb_list = []
+ ori_rgb_list = []
+ normal_list = []
+
+ # 逐行读取数据
+ for line in file:
+ # 去除行尾的换行符并分割字符串
+ x, y, z, r, g, b, nx, ny, nz, w = line.split()
+ # 将字符串转换为浮点数
+ x = float(x)
+ y = float(y)
+ z = float(z)
+ r = float(r)
+ g = float(g)
+ b = float(b)
+ nx = float(nx)
+ ny = float(ny)
+ nz = float(nz)
+ # 将点添加到列表中
+ points.append((x, y, z))
+ rgb_list.append((r/255.0, g/255.0 , b/255.0))
+ normal_list.append((nx, ny, nz))
+
+ ori_r = int(r)
+ ori_g = int(g)
+ ori_b = int(b)
+ ori_rgb_list.append((ori_r, ori_g , ori_b))
+
+ np_points = np.asarray(points)
+
+ np_points_a = np_points
+
+ np_colors = np.asarray(rgb_list)
+ np_normals = np.asarray(normal_list)
+
+ np_colors_ori = np.asarray(ori_rgb_list)
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(np_points)
+ pcd.colors = o3d.utility.Vector3dVector(np_colors)
+ pcd.normals = o3d.utility.Vector3dVector(np_normals)
+
+ map_dict = {}
+
+ image = np.ones((1000, 1000, 3),dtype=np.uint8)*255
+ for i in range(np.array(pcd.points).shape[0]):
+ x = np.array(pcd.points)[i,0]+400
+ y = np.array(pcd.points)[i,1]+400
+
+ image[int(x),int(y),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
+ image[int(x+1),int(y),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
+ image[int(x),int(y+1),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
+ image[int(x-1),int(y),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
+ image[int(x),int(y-1),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
+
+ map_dict[str(int(x)) + '_' + str(int(y))] = i
+ map_dict[str(int(x+1)) + '_' + str(int(y))] = i
+ map_dict[str(int(x)) + '_' + str(int(y+1))] = i
+ map_dict[str(int(x-1)) + '_' + str(int(y))] = i
+ map_dict[str(int(x)) + '_' + str(int(y-1))] = i
+
+ # if [int(y), int(x)] in cord_list:
+ # image[int(x),int(y),:] = np.array([0, 255, 0])
+
+ # if [int(y), int(x+1)] in cord_list:
+ # image[int(x+1),int(y),:] = np.array([0, 255, 0])
+
+ # if [int(y+1), int(x)] in cord_list:
+ # image[int(x),int(y+1),:] = np.array([0, 255, 0])
+
+ # if [int(y), int(x-1)] in cord_list:
+ # image[int(x-1),int(y),:] = np.array([0, 255, 0])
+
+ # if [int(y-1), int(x)] in cord_list:
+ # image[int(x),int(y-1),:] = np.array([0, 255, 0])
+
+ # if [int(y-1), int(x-1)] in cord_list:
+ # image[int(x-1),int(y-1),:] = np.array([0, 255, 0])
+
+ # if [int(y+1), int(x+1)] in cord_list:
+ # image[int(x+1),int(y+1),:] = np.array([0, 255, 0])
+
+ h_list = []
+ for m in cord_list:
+ a, b = m[0], m[1]
+ c = image[int(b),int(a),:][0]
+
+ flag = False
+
+ if image[int(b),int(a),:][1] != 255:
+ h_list.append(str(int(b))+'_'+str(int(a)))
+ flag = True
+ else:
+ if image[int(b)-2,int(a)-2,:][1] != 255:
+ h_list.append(str(int(b)-2)+'_'+str(int(a)-2))
+ flag = True
+ elif image[int(b)+2,int(a)+2,:][1] != 255:
+ h_list.append(str(int(b)+2)+'_'+str(int(a)+2))
+ flag = True
+ elif image[int(b),int(a)-3,:][1] != 255:
+ h_list.append(str(int(b))+'_'+str(int(a)-3))
+ flag = True
+
+ # if flag == False:
+ # cc = image[int(b),int(a),:][1]
+
+ # cv2.circle(image, (465,505), 2, (0, 255, 0), -1)
+
+ # cv2.imshow('win', image)
+ # cv2.waitKey(0)
+
+ with open('pid.txt', 'w') as f:
+ for h in h_list:
+ pid = map_dict[h]
+ s = str(pid) + '\n'
+ f.write(s)
+
+ np_colors[pid,:] = np.array([0, 255, 0])
+
+ f.close()
+
+ pcd0 = o3d.geometry.PointCloud()
+ pcd0.points = o3d.utility.Vector3dVector(np_points)
+ pcd0.colors = o3d.utility.Vector3dVector(np_colors)
+ pcd0.normals = o3d.utility.Vector3dVector(np_normals)
+
+ o3d.io.write_point_cloud('aa.ply', pcd0)
+
+
+ mm = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+ image3 = cv2.flip(mm, -1)
+
+ # cv2.imwrite('./rgb.png', image3)
+
+with open('./cord.txt', 'r') as f:
+ lines = f.readlines()
+ for line in lines:
+ m = line.split()
+ x = int(m[0])
+ y = int(m[1])
+
+ x = 1000 - x
+ y = 1000 - y
+
+ cv2.circle(image, (x,y), 2, (0, 255, 0), -1)
+
+ idx = map_dict[str(x)+'_'+str(y)]
+
+ a = 0
+
+# cv2.imshow("win", image)
+# cv2.waitKey(0)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ # import matplotlib.pyplot as plt
+ # plt.imshow(image)
+ # plt.show()
+
+ # save_pcd_path = os.path.join(output_folder, name[:-3]+'ply')
+ # # o3d.io.write_point_cloud(save_pcd_path, pcd)
+
+ # # render
+ # import trimesh
+ # # fuze_trimesh = trimesh.load('/home/gyalex/Desktop/PointImage32.obj')
+ # # mesh = pyrender.Mesh.from_trimesh(fuze_trimesh)
+ # mesh = pyrender.Mesh.from_points(np_points, np_colors_ori, np_normals)
+
+ # import math
+ # camera = PerspectiveCamera(yfov=math.pi / 3, aspectRatio=1.0)
+ # camera_pose = np.array([[-1.0, 0.0, 0.0, 0], \
+ # [0.0, 1.0, 0.0, 0], \
+ # [0.0, 0.0, -1.0, 0], \
+ # [0.0, 0.0, 0.0, 1.0]])
+
+ # # 创建场景
+ # scene = pyrender.Scene()
+ # scene.add(mesh)
+ # scene.add(camera, pose=camera_pose)
+
+ # # light = pyrender.SpotLight(color=np.ones(3), intensity=3.0, innerConeAngle=np.pi/16.0, outerConeAngle=np.pi/6.0)
+ # # scene.add(light, pose=camera_pose)
+
+ # # 渲染场景
+ # renderer = pyrender.OffscreenRenderer(viewport_width=1280, viewport_height=1024)
+ # color, depth = renderer.render(scene)
+
+ # # # 设置场景和光源
+ # # scene = pyrender.Scene()
+ # # scene.add(point_cloud_mesh, 'point_cloud')
+ # # camera = PerspectiveCamera(yfov=45.0, aspectRatio=1.0)
+ # # scene.add(camera)
+
+ # # # 渲染场景
+ # # renderer = pyrender.OffscreenRenderer(viewport_width=1280, viewport_height=1024)
+ # # color, depth = renderer.render(scene)
+
+ # # 保存渲染结果为图片
+ # import cv2
+ # cv2.imshow('win', color)
+
+ # rgb_img = cv2.imread('/media/gyalex/Data/face_det_dataset/rgbd_data/color_32.bmp')
+ # cv2.imshow('win0', rgb_img)
+ # cv2.waitKey(0)
\ No newline at end of file
diff --git a/external/landmark_detection/evaluate.py b/external/landmark_detection/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a7af2d727cfc0efdd42249106bd3f8cf561e2f9
--- /dev/null
+++ b/external/landmark_detection/evaluate.py
@@ -0,0 +1,258 @@
+import os
+import cv2
+import math
+import argparse
+import numpy as np
+from tqdm import tqdm
+
+import torch
+
+# private package
+from lib import utility
+
+
+
+class GetCropMatrix():
+ """
+ from_shape -> transform_matrix
+ """
+
+ def __init__(self, image_size, target_face_scale, align_corners=False):
+ self.image_size = image_size
+ self.target_face_scale = target_face_scale
+ self.align_corners = align_corners
+
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+ cosv = math.cos(angle)
+ sinv = math.sin(angle)
+
+ fx, fy = from_center
+ tx, ty = to_center
+
+ acos = scale * cosv
+ asin = scale * sinv
+
+ a0 = acos
+ a1 = -asin
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+ b0 = asin
+ b1 = acos
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+ rot_scale_m = np.array([
+ [a0, a1, a2],
+ [b0, b1, b2],
+ [0.0, 0.0, 1.0]
+ ], np.float32)
+ return rot_scale_m
+
+ def process(self, scale, center_w, center_h):
+ if self.align_corners:
+ to_w, to_h = self.image_size - 1, self.image_size - 1
+ else:
+ to_w, to_h = self.image_size, self.image_size
+
+ rot_mu = 0
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
+ shift_xy_mu = (0, 0)
+ matrix = self._compose_rotate_and_scale(
+ rot_mu, scale_mu, shift_xy_mu,
+ from_center=[center_w, center_h],
+ to_center=[to_w / 2.0, to_h / 2.0])
+ return matrix
+
+
+class TransformPerspective():
+ """
+ image, matrix3x3 -> transformed_image
+ """
+
+ def __init__(self, image_size):
+ self.image_size = image_size
+
+ def process(self, image, matrix):
+ return cv2.warpPerspective(
+ image, matrix, dsize=(self.image_size, self.image_size),
+ flags=cv2.INTER_LINEAR, borderValue=0)
+
+
+class TransformPoints2D():
+ """
+ points (nx2), matrix (3x3) -> points (nx2)
+ """
+
+ def process(self, srcPoints, matrix):
+ # nx3
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
+ desPoints = desPoints @ np.transpose(matrix) # nx3
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
+ return desPoints.astype(srcPoints.dtype)
+
+
+class Alignment:
+ def __init__(self, args, model_path, dl_framework, device_ids):
+ self.input_size = 256
+ self.target_face_scale = 1.0
+ self.dl_framework = dl_framework
+
+ # model
+ if self.dl_framework == "pytorch":
+ # conf
+ self.config = utility.get_config(args)
+ self.config.device_id = device_ids[0]
+ # set environment
+ utility.set_environment(self.config)
+ self.config.init_instance()
+ if self.config.logger is not None:
+ self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
+ self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
+
+ net = utility.get_net(self.config)
+ if device_ids == [-1]:
+ checkpoint = torch.load(model_path, map_location="cpu")
+ else:
+ checkpoint = torch.load(model_path)
+ net.load_state_dict(checkpoint["net"])
+ net = net.to(self.config.device_id)
+ net.eval()
+ self.alignment = net
+ else:
+ assert False
+
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
+ align_corners=True)
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
+ self.transformPoints2D = TransformPoints2D()
+
+ def norm_points(self, points, align_corners=False):
+ if align_corners:
+ # [0, SIZE-1] -> [-1, +1]
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
+ else:
+ # [-0.5, SIZE-0.5] -> [-1, +1]
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
+
+ def denorm_points(self, points, align_corners=False):
+ if align_corners:
+ # [-1, +1] -> [0, SIZE-1]
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
+ else:
+ # [-1, +1] -> [-0.5, SIZE-0.5]
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
+
+ def preprocess(self, image, scale, center_w, center_h):
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
+ input_tensor = self.transformPerspective.process(image, matrix)
+ input_tensor = input_tensor[np.newaxis, :]
+
+ input_tensor = torch.from_numpy(input_tensor)
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
+ input_tensor = input_tensor.to(self.config.device_id)
+ return input_tensor, matrix
+
+ def postprocess(self, srcPoints, coeff):
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
+ # matrix^(-1) * src = dst
+ # src = matrix * dst
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
+ for i in range(srcPoints.shape[0]):
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
+ return dstPoints
+
+ def analyze(self, image, scale, center_w, center_h):
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
+
+ if self.dl_framework == "pytorch":
+ with torch.no_grad():
+ output = self.alignment(input_tensor)
+ landmarks = output[-1][0]
+ else:
+ assert False
+
+ landmarks = self.denorm_points(landmarks)
+ landmarks = landmarks.data.cpu().numpy()[0]
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
+
+ return landmarks
+
+
+def L2(p1, p2):
+ return np.linalg.norm(p1 - p2)
+
+
+def NME(landmarks_gt, landmarks_pv):
+ pts_num = landmarks_gt.shape[0]
+ if pts_num == 29:
+ left_index = 16
+ right_index = 17
+ elif pts_num == 68:
+ left_index = 36
+ right_index = 45
+ elif pts_num == 98:
+ left_index = 60
+ right_index = 72
+
+ nme = 0
+ eye_span = L2(landmarks_gt[left_index], landmarks_gt[right_index])
+ for i in range(pts_num):
+ error = L2(landmarks_pv[i], landmarks_gt[i])
+ nme += error / eye_span
+ nme /= pts_num
+ return nme
+
+
+def evaluate(args, model_path, metadata_path, device_ids, mode):
+ alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
+ config = alignment.config
+ nme_sum = 0
+ with open(metadata_path, 'r') as f:
+ lines = f.readlines()
+ for k, line in enumerate(tqdm(lines)):
+ item = line.strip().split("\t")
+ image_name, landmarks_5pts, landmarks_gt, scale, center_w, center_h = item[:6]
+ # image & keypoints alignment
+ image_name = image_name.replace('\\', '/')
+ image_name = image_name.replace('//msr-facestore/Workspace/MSRA_EP_Allergan/users/yanghuan/training_data/wflw/rawImages/', '')
+ image_name = image_name.replace('./rawImages/', '')
+ image_path = os.path.join(config.image_dir, image_name)
+ landmarks_gt = np.array(list(map(float, landmarks_gt.split(","))), dtype=np.float32).reshape(-1, 2)
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
+
+ image = cv2.imread(image_path)
+ landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
+
+ # NME
+ if mode == "nme":
+ nme = NME(landmarks_gt, landmarks_pv)
+ nme_sum += nme
+ # print("Current NME(%d): %f" % (k + 1, (nme_sum / (k + 1))))
+ else:
+ pass
+
+ if mode == "nme":
+ print("Final NME: %f" % (100*nme_sum / (k + 1)))
+ else:
+ pass
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Evaluation script")
+ parser.add_argument("--config_name", type=str, default="alignment", help="set configure file name")
+ parser.add_argument("--model_path", type=str, default="./train.pkl", help="the path of model")
+ parser.add_argument("--data_definition", type=str, default='WFLW', help="COFW/300W/WFLW")
+ parser.add_argument("--metadata_path", type=str, default="", help="the path of metadata")
+ parser.add_argument("--image_dir", type=str, default="", help="the path of image")
+ parser.add_argument("--device_ids", type=str, default="0", help="set device ids, -1 means use cpu device, >= 0 means use gpu device")
+ parser.add_argument("--mode", type=str, default="nme", help="set the evaluate mode: nme")
+ args = parser.parse_args()
+
+ device_ids = list(map(int, args.device_ids.split(",")))
+ evaluate(
+ args,
+ model_path=args.model_path,
+ metadata_path=args.metadata_path,
+ device_ids=device_ids,
+ mode=args.mode)
diff --git a/external/landmark_detection/infer_folder.py b/external/landmark_detection/infer_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a34c75d99934e5d56d42a1b23c91a025f31cf35d
--- /dev/null
+++ b/external/landmark_detection/infer_folder.py
@@ -0,0 +1,253 @@
+import cv2
+import math
+import copy
+import numpy as np
+import argparse
+import torch
+import json
+
+# private package
+from lib import utility
+from FaceBoxesV2.faceboxes_detector import *
+
+class GetCropMatrix():
+ """
+ from_shape -> transform_matrix
+ """
+
+ def __init__(self, image_size, target_face_scale, align_corners=False):
+ self.image_size = image_size
+ self.target_face_scale = target_face_scale
+ self.align_corners = align_corners
+
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+ cosv = math.cos(angle)
+ sinv = math.sin(angle)
+
+ fx, fy = from_center
+ tx, ty = to_center
+
+ acos = scale * cosv
+ asin = scale * sinv
+
+ a0 = acos
+ a1 = -asin
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+ b0 = asin
+ b1 = acos
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+ rot_scale_m = np.array([
+ [a0, a1, a2],
+ [b0, b1, b2],
+ [0.0, 0.0, 1.0]
+ ], np.float32)
+ return rot_scale_m
+
+ def process(self, scale, center_w, center_h):
+ if self.align_corners:
+ to_w, to_h = self.image_size - 1, self.image_size - 1
+ else:
+ to_w, to_h = self.image_size, self.image_size
+
+ rot_mu = 0
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
+ shift_xy_mu = (0, 0)
+ matrix = self._compose_rotate_and_scale(
+ rot_mu, scale_mu, shift_xy_mu,
+ from_center=[center_w, center_h],
+ to_center=[to_w / 2.0, to_h / 2.0])
+ return matrix
+
+
+class TransformPerspective():
+ """
+ image, matrix3x3 -> transformed_image
+ """
+
+ def __init__(self, image_size):
+ self.image_size = image_size
+
+ def process(self, image, matrix):
+ return cv2.warpPerspective(
+ image, matrix, dsize=(self.image_size, self.image_size),
+ flags=cv2.INTER_LINEAR, borderValue=0)
+
+
+class TransformPoints2D():
+ """
+ points (nx2), matrix (3x3) -> points (nx2)
+ """
+
+ def process(self, srcPoints, matrix):
+ # nx3
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
+ desPoints = desPoints @ np.transpose(matrix) # nx3
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
+ return desPoints.astype(srcPoints.dtype)
+
+class Alignment:
+ def __init__(self, args, model_path, dl_framework, device_ids):
+ self.input_size = 256
+ self.target_face_scale = 1.0
+ self.dl_framework = dl_framework
+
+ # model
+ if self.dl_framework == "pytorch":
+ # conf
+ self.config = utility.get_config(args)
+ self.config.device_id = device_ids[0]
+ # set environment
+ utility.set_environment(self.config)
+ # self.config.init_instance()
+ # if self.config.logger is not None:
+ # self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
+ # self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
+
+ net = utility.get_net(self.config)
+ if device_ids == [-1]:
+ checkpoint = torch.load(model_path, map_location="cpu")
+ else:
+ checkpoint = torch.load(model_path)
+ net.load_state_dict(checkpoint["net"])
+
+ if self.config.device_id == -1:
+ net = net.cpu()
+ else:
+ net = net.to(self.config.device_id)
+
+ net.eval()
+ self.alignment = net
+ else:
+ assert False
+
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
+ align_corners=True)
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
+ self.transformPoints2D = TransformPoints2D()
+
+ def norm_points(self, points, align_corners=False):
+ if align_corners:
+ # [0, SIZE-1] -> [-1, +1]
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
+ else:
+ # [-0.5, SIZE-0.5] -> [-1, +1]
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
+
+ def denorm_points(self, points, align_corners=False):
+ if align_corners:
+ # [-1, +1] -> [0, SIZE-1]
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
+ else:
+ # [-1, +1] -> [-0.5, SIZE-0.5]
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
+
+ def preprocess(self, image, scale, center_w, center_h):
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
+ input_tensor = self.transformPerspective.process(image, matrix)
+ input_tensor = input_tensor[np.newaxis, :]
+
+ input_tensor = torch.from_numpy(input_tensor)
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
+
+ if self.config.device_id == -1:
+ input_tensor = input_tensor.cpu()
+ else:
+ input_tensor = input_tensor.to(self.config.device_id)
+
+ return input_tensor, matrix
+
+ def postprocess(self, srcPoints, coeff):
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
+ # matrix^(-1) * src = dst
+ # src = matrix * dst
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
+ for i in range(srcPoints.shape[0]):
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
+ return dstPoints
+
+ def analyze(self, image, scale, center_w, center_h):
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
+
+ if self.dl_framework == "pytorch":
+ with torch.no_grad():
+ output = self.alignment(input_tensor)
+ landmarks = output[-1][0]
+ else:
+ assert False
+
+ landmarks = self.denorm_points(landmarks)
+ landmarks = landmarks.data.cpu().numpy()[0]
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
+
+ return landmarks
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="inference script")
+ parser.add_argument('--folder_path', type=str, help='Path to image folder')
+ args = parser.parse_args()
+
+ # args.folder_path = '/media/gyalex/Data/flame/ph_test/head_images/flame/image'
+
+ current_path = os.getcwd()
+
+ use_gpu = True
+ ########### face detection ############
+ if use_gpu:
+ device = torch.device("cuda:0")
+ else:
+ device = torch.device("cpu")
+
+ current_path = os.getcwd()
+ det_model_path = os.path.join(current_path, 'preprocess', 'submodules', 'Landmark_detection', 'FaceBoxesV2/weights/FaceBoxesV2.pth')
+ detector = FaceBoxesDetector('FaceBoxes', det_model_path, use_gpu, device)
+
+ ########### facial alignment ############
+ model_path = os.path.join(current_path, 'preprocess', 'submodules', 'Landmark_detection', 'weights/68_keypoints_model.pkl')
+
+ if use_gpu:
+ device_ids = [0]
+ else:
+ device_ids = [-1]
+
+ args.config_name = 'alignment'
+ alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
+
+ img_path_list = os.listdir(args.folder_path)
+ kpts_code = dict()
+
+ ########### inference ############
+ for file_name in img_path_list:
+ abs_path = os.path.join(args.folder_path, file_name)
+
+ image = cv2.imread(abs_path)
+ image_draw = copy.deepcopy(image)
+
+ detections, _ = detector.detect(image, 0.6, 1)
+ for idx in range(len(detections)):
+ x1_ori = detections[idx][2]
+ y1_ori = detections[idx][3]
+ x2_ori = x1_ori + detections[idx][4]
+ y2_ori = y1_ori + detections[idx][5]
+
+ scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
+ center_w = (x1_ori + x2_ori) / 2
+ center_h = (y1_ori + y2_ori) / 2
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
+
+ landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
+ landmarks_pv_list = landmarks_pv.tolist()
+
+ for num in range(landmarks_pv.shape[0]):
+ cv2.circle(image_draw, (round(landmarks_pv[num][0]), round(landmarks_pv[num][1])),
+ 2, (0, 255, 0), -1)
+
+ kpts_code[file_name] = landmarks_pv_list
+ save_path = args.folder_path[:-5] + 'landmark'
+ cv2.imwrite(os.path.join(save_path, file_name), image_draw)
+
+ path = args.folder_path[:-5]
+ json.dump(kpts_code, open(os.path.join(path, 'keypoint.json'), 'w'))
diff --git a/external/landmark_detection/infer_image.py b/external/landmark_detection/infer_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e42a12bbb9b6bf497ed17d34afc588e7047430
--- /dev/null
+++ b/external/landmark_detection/infer_image.py
@@ -0,0 +1,251 @@
+import cv2
+import math
+import copy
+import numpy as np
+import argparse
+import torch
+
+# private package
+from external.landmark_detection.lib import utility
+from external.landmark_detection.FaceBoxesV2.faceboxes_detector import *
+
+class GetCropMatrix():
+ """
+ from_shape -> transform_matrix
+ """
+
+ def __init__(self, image_size, target_face_scale, align_corners=False):
+ self.image_size = image_size
+ self.target_face_scale = target_face_scale
+ self.align_corners = align_corners
+
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+ cosv = math.cos(angle)
+ sinv = math.sin(angle)
+
+ fx, fy = from_center
+ tx, ty = to_center
+
+ acos = scale * cosv
+ asin = scale * sinv
+
+ a0 = acos
+ a1 = -asin
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+ b0 = asin
+ b1 = acos
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+ rot_scale_m = np.array([
+ [a0, a1, a2],
+ [b0, b1, b2],
+ [0.0, 0.0, 1.0]
+ ], np.float32)
+ return rot_scale_m
+
+ def process(self, scale, center_w, center_h):
+ if self.align_corners:
+ to_w, to_h = self.image_size - 1, self.image_size - 1
+ else:
+ to_w, to_h = self.image_size, self.image_size
+
+ rot_mu = 0
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
+ shift_xy_mu = (0, 0)
+ matrix = self._compose_rotate_and_scale(
+ rot_mu, scale_mu, shift_xy_mu,
+ from_center=[center_w, center_h],
+ to_center=[to_w / 2.0, to_h / 2.0])
+ return matrix
+
+
+class TransformPerspective():
+ """
+ image, matrix3x3 -> transformed_image
+ """
+
+ def __init__(self, image_size):
+ self.image_size = image_size
+
+ def process(self, image, matrix):
+ return cv2.warpPerspective(
+ image, matrix, dsize=(self.image_size, self.image_size),
+ flags=cv2.INTER_LINEAR, borderValue=0)
+
+
+class TransformPoints2D():
+ """
+ points (nx2), matrix (3x3) -> points (nx2)
+ """
+
+ def process(self, srcPoints, matrix):
+ # nx3
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
+ desPoints = desPoints @ np.transpose(matrix) # nx3
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
+ return desPoints.astype(srcPoints.dtype)
+
+class Alignment:
+ def __init__(self, args, model_path, dl_framework, device_ids):
+ self.input_size = 256
+ self.target_face_scale = 1.0
+ self.dl_framework = dl_framework
+
+ # model
+ if self.dl_framework == "pytorch":
+ # conf
+ self.config = utility.get_config(args)
+ self.config.device_id = device_ids[0]
+ # set environment
+ # utility.set_environment(self.config)
+ # self.config.init_instance()
+ # if self.config.logger is not None:
+ # self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
+ # self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
+
+ net = utility.get_net(self.config)
+ if device_ids == [-1]:
+ checkpoint = torch.load(model_path, map_location="cpu")
+ else:
+ checkpoint = torch.load(model_path)
+ net.load_state_dict(checkpoint["net"])
+
+ if self.config.device_id == -1:
+ net = net.cpu()
+ else:
+ net = net.to(self.config.device_id)
+
+ net.eval()
+ self.alignment = net
+ else:
+ assert False
+
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
+ align_corners=True)
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
+ self.transformPoints2D = TransformPoints2D()
+
+ def norm_points(self, points, align_corners=False):
+ if align_corners:
+ # [0, SIZE-1] -> [-1, +1]
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
+ else:
+ # [-0.5, SIZE-0.5] -> [-1, +1]
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
+
+ def denorm_points(self, points, align_corners=False):
+ if align_corners:
+ # [-1, +1] -> [0, SIZE-1]
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
+ else:
+ # [-1, +1] -> [-0.5, SIZE-0.5]
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
+
+ def preprocess(self, image, scale, center_w, center_h):
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
+ input_tensor = self.transformPerspective.process(image, matrix)
+ input_tensor = input_tensor[np.newaxis, :]
+
+ input_tensor = torch.from_numpy(input_tensor)
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
+
+ if self.config.device_id == -1:
+ input_tensor = input_tensor.cpu()
+ else:
+ input_tensor = input_tensor.to(self.config.device_id)
+
+ return input_tensor, matrix
+
+ def postprocess(self, srcPoints, coeff):
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
+ # matrix^(-1) * src = dst
+ # src = matrix * dst
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
+ for i in range(srcPoints.shape[0]):
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
+ return dstPoints
+
+ def analyze(self, image, scale, center_w, center_h):
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
+
+ if self.dl_framework == "pytorch":
+ with torch.no_grad():
+ output = self.alignment(input_tensor)
+ landmarks = output[-1][0]
+ else:
+ assert False
+
+ landmarks = self.denorm_points(landmarks)
+ landmarks = landmarks.data.cpu().numpy()[0]
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
+
+ return landmarks
+
+# parser = argparse.ArgumentParser(description="Evaluation script")
+# args = parser.parse_args()
+# image_path = './rgb.png'
+# image = cv2.imread(image_path)
+#
+# use_gpu = False
+# ########### face detection ############
+# if use_gpu:
+# device = torch.device("cuda:0")
+# else:
+# device = torch.device("cpu")
+#
+# detector = FaceBoxesDetector('FaceBoxes', 'FaceBoxesV2/weights/FaceBoxesV2.pth', use_gpu, device)
+#
+# ########### facial alignment ############
+# model_path = './weights/68_keypoints_model.pkl'
+#
+# if use_gpu:
+# device_ids = [0]
+# else:
+# device_ids = [-1]
+#
+# args.config_name = 'alignment'
+# alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
+# image_draw = copy.deepcopy(image)
+#
+# ########### inference ############
+# ldk_list = []
+#
+# detections, _ = detector.detect(image, 0.9, 1)
+# for idx in range(len(detections)):
+# x1_ori = detections[idx][2]
+# y1_ori = detections[idx][3]
+# x2_ori = x1_ori + detections[idx][4]
+# y2_ori = y1_ori + detections[idx][5]
+#
+# scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
+# center_w = (x1_ori + x2_ori) / 2
+# center_h = (y1_ori + y2_ori) / 2
+# scale, center_w, center_h = float(scale), float(center_w), float(center_h)
+#
+# landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
+#
+# for num in range(landmarks_pv.shape[0]):
+# cv2.circle(image_draw, (round(landmarks_pv[num][0]), round(landmarks_pv[num][1])),
+# 2, (0, 255, 0), -1)
+#
+# ldk_list.append([round(landmarks_pv[num][0]), round(landmarks_pv[num][1])])
+#
+# cv2.imshow("win", image_draw)
+#
+# # ldk_img = cv2.imread('/home/gyalex/Desktop/image_landmark_149/all.jpg')
+# # cv2.imshow("win1", ldk_img)
+#
+# cv2.waitKey(0)
+#
+# with open('./cord.txt', 'w') as f:
+# for num in range(len(ldk_list)):
+# s = str(ldk_list[num][0]) + ' ' + str(ldk_list[num][1]) + '\n'
+# f.write(s)
+#
+# f.close()
+
+
+
diff --git a/external/landmark_detection/infer_video.py b/external/landmark_detection/infer_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..4232c200f372204445f83274bd8bdd9dbb503190
--- /dev/null
+++ b/external/landmark_detection/infer_video.py
@@ -0,0 +1,287 @@
+import cv2
+import math
+import copy
+import numpy as np
+import argparse
+import torch
+import json
+
+# private package
+from lib import utility
+from FaceBoxesV2.faceboxes_detector import *
+
+class GetCropMatrix():
+ """
+ from_shape -> transform_matrix
+ """
+
+ def __init__(self, image_size, target_face_scale, align_corners=False):
+ self.image_size = image_size
+ self.target_face_scale = target_face_scale
+ self.align_corners = align_corners
+
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+ cosv = math.cos(angle)
+ sinv = math.sin(angle)
+
+ fx, fy = from_center
+ tx, ty = to_center
+
+ acos = scale * cosv
+ asin = scale * sinv
+
+ a0 = acos
+ a1 = -asin
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+ b0 = asin
+ b1 = acos
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+ rot_scale_m = np.array([
+ [a0, a1, a2],
+ [b0, b1, b2],
+ [0.0, 0.0, 1.0]
+ ], np.float32)
+ return rot_scale_m
+
+ def process(self, scale, center_w, center_h):
+ if self.align_corners:
+ to_w, to_h = self.image_size - 1, self.image_size - 1
+ else:
+ to_w, to_h = self.image_size, self.image_size
+
+ rot_mu = 0
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
+ shift_xy_mu = (0, 0)
+ matrix = self._compose_rotate_and_scale(
+ rot_mu, scale_mu, shift_xy_mu,
+ from_center=[center_w, center_h],
+ to_center=[to_w / 2.0, to_h / 2.0])
+ return matrix
+
+
+class TransformPerspective():
+ """
+ image, matrix3x3 -> transformed_image
+ """
+
+ def __init__(self, image_size):
+ self.image_size = image_size
+
+ def process(self, image, matrix):
+ return cv2.warpPerspective(
+ image, matrix, dsize=(self.image_size, self.image_size),
+ flags=cv2.INTER_LINEAR, borderValue=0)
+
+
+class TransformPoints2D():
+ """
+ points (nx2), matrix (3x3) -> points (nx2)
+ """
+
+ def process(self, srcPoints, matrix):
+ # nx3
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
+ desPoints = desPoints @ np.transpose(matrix) # nx3
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
+ return desPoints.astype(srcPoints.dtype)
+
+class Alignment:
+ def __init__(self, args, model_path, dl_framework, device_ids):
+ self.input_size = 256
+ self.target_face_scale = 1.0
+ self.dl_framework = dl_framework
+
+ # model
+ if self.dl_framework == "pytorch":
+ # conf
+ self.config = utility.get_config(args)
+ self.config.device_id = device_ids[0]
+ # set environment
+ utility.set_environment(self.config)
+ # self.config.init_instance()
+ # if self.config.logger is not None:
+ # self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
+ # self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
+
+ net = utility.get_net(self.config)
+ if device_ids == [-1]:
+ checkpoint = torch.load(model_path, map_location="cpu")
+ else:
+ checkpoint = torch.load(model_path)
+ net.load_state_dict(checkpoint["net"])
+
+ if self.config.device_id == -1:
+ net = net.cpu()
+ else:
+ net = net.to(self.config.device_id)
+
+ net.eval()
+ self.alignment = net
+ else:
+ assert False
+
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
+ align_corners=True)
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
+ self.transformPoints2D = TransformPoints2D()
+
+ def norm_points(self, points, align_corners=False):
+ if align_corners:
+ # [0, SIZE-1] -> [-1, +1]
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
+ else:
+ # [-0.5, SIZE-0.5] -> [-1, +1]
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
+
+ def denorm_points(self, points, align_corners=False):
+ if align_corners:
+ # [-1, +1] -> [0, SIZE-1]
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
+ else:
+ # [-1, +1] -> [-0.5, SIZE-0.5]
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
+
+ def preprocess(self, image, scale, center_w, center_h):
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
+ input_tensor = self.transformPerspective.process(image, matrix)
+ input_tensor = input_tensor[np.newaxis, :]
+
+ input_tensor = torch.from_numpy(input_tensor)
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
+
+ if self.config.device_id == -1:
+ input_tensor = input_tensor.cpu()
+ else:
+ input_tensor = input_tensor.to(self.config.device_id)
+
+ return input_tensor, matrix
+
+ def postprocess(self, srcPoints, coeff):
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
+ # matrix^(-1) * src = dst
+ # src = matrix * dst
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
+ for i in range(srcPoints.shape[0]):
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
+ return dstPoints
+
+ def analyze(self, image, scale, center_w, center_h):
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
+
+ if self.dl_framework == "pytorch":
+ with torch.no_grad():
+ output = self.alignment(input_tensor)
+ landmarks = output[-1][0]
+ else:
+ assert False
+
+ landmarks = self.denorm_points(landmarks)
+ landmarks = landmarks.data.cpu().numpy()[0]
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
+
+ return landmarks
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="inference script")
+ parser.add_argument('--video_path', type=str, help='Path to videos',default='/media/yuanzhen/HH/DATASET/VFTH/TESTVIDEO/Clip+7CzHzeeVRlE+P0+C0+F101007-101139.mp4')
+ args = parser.parse_args()
+
+ # args.video_path = '/media/gyalex/Data/flame/ph_test/test.mp4'
+
+ current_path = os.getcwd()
+
+ use_gpu = True
+ ########### face detection ############
+ if use_gpu:
+ device = torch.device("cuda:0")
+ else:
+ device = torch.device("cpu")
+
+ current_path = os.getcwd()
+ det_model_path = '/home/yuanzhen/code/landmark_detection/FaceBoxesV2/weights/FaceBoxesV2.pth'
+ detector = FaceBoxesDetector('FaceBoxes', det_model_path, use_gpu, device)
+
+ ########### facial alignment ############
+ model_path = '/home/yuanzhen/code/landmark_detection/weights/68_keypoints_model.pkl'
+
+ if use_gpu:
+ device_ids = [0]
+ else:
+ device_ids = [-1]
+
+ args.config_name = 'alignment'
+ alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
+
+ video_file = args.video_path
+ cap = cv2.VideoCapture(video_file)
+ frame_width = int(cap.get(3))
+ frame_height = int(cap.get(4))
+
+ # out_video_file = './output_video.mp4'
+ # fps = 30
+ # size = (frame_width, frame_height)
+ # out = cv2.VideoWriter(out_video_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
+
+ count = 0
+ kpts_code = dict()
+
+ keypoint_data_path = args.video_path.replace('.mp4','.json')
+ with open(keypoint_data_path,'r') as f:
+ keypoint_data = json.load(f)
+
+ ########### inference ############
+ path = video_file[:-4]
+ while(cap.isOpened()):
+ ret, image = cap.read()
+
+ if ret:
+ detections, _ = detector.detect(image, 0.8, 1)
+ image_draw = copy.deepcopy(image)
+
+ cv2.imwrite(os.path.join(path, 'image', str(count+1)+'.png'), image_draw)
+
+ for idx in range(len(detections)):
+ x1_ori = detections[idx][2]
+ y1_ori = detections[idx][3]
+ x2_ori = x1_ori + detections[idx][4]
+ y2_ori = y1_ori + detections[idx][5]
+
+ scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
+ center_w = (x1_ori + x2_ori) / 2
+ center_h = (y1_ori + y2_ori) / 2
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
+
+ # landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
+ landmarks_pv = np.array(keypoint_data[str(count+1)+'.png'])
+
+ landmarks_pv_list = landmarks_pv.tolist()
+
+ for num in range(landmarks_pv.shape[0]):
+ cv2.circle(image_draw, (round(landmarks_pv[num][0]), round(landmarks_pv[num][1])),
+ 2, (0, 255, 0), -1)
+ cv2.putText(image_draw, str(num),
+ (round(landmarks_pv[num][0]) + 5, round(landmarks_pv[num][1]) + 5), # 文本位置
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA)
+
+ kpts_code[str(count+1)+'.png'] = landmarks_pv_list
+ cv2.imwrite(os.path.join(path, 'landmark', str(count+1)+'.png'), image_draw)
+ else:
+ break
+
+ count += 1
+
+ cap.release()
+ # out.release()
+ # cv2.destroyAllWindows()
+
+ path = video_file[:-4]
+ json.dump(kpts_code, open(os.path.join(path, 'keypoint.json'), 'w'))
+
+ print(path)
+
+
+
diff --git a/external/landmark_detection/lib/__init__.py b/external/landmark_detection/lib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bca232a42a2f935059a05007b53df3c5e922569b
--- /dev/null
+++ b/external/landmark_detection/lib/__init__.py
@@ -0,0 +1,9 @@
+from .dataset import get_encoder, get_decoder
+from .dataset import AlignmentDataset, Augmentation
+from .backbone import StackedHGNetV1
+from .metric import NME, Accuracy
+from .utils import time_print, time_string, time_for_file, time_string_short
+from .utils import convert_secs2time, convert_size2str
+
+from .utility import get_dataloader, get_config, get_net, get_criterions
+from .utility import get_optimizer, get_scheduler
diff --git a/external/landmark_detection/lib/backbone/__init__.py b/external/landmark_detection/lib/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb1578aa88a0cd75d9793c48a73638ec813a0e4a
--- /dev/null
+++ b/external/landmark_detection/lib/backbone/__init__.py
@@ -0,0 +1,5 @@
+from .stackedHGNetV1 import StackedHGNetV1
+
+__all__ = [
+ "StackedHGNetV1",
+]
\ No newline at end of file
diff --git a/external/landmark_detection/lib/backbone/core/coord_conv.py b/external/landmark_detection/lib/backbone/core/coord_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..7239421d9e9b880ddc3dd66b443d3933a6f523f3
--- /dev/null
+++ b/external/landmark_detection/lib/backbone/core/coord_conv.py
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+
+
+class AddCoordsTh(nn.Module):
+ def __init__(self, x_dim, y_dim, with_r=False, with_boundary=False):
+ super(AddCoordsTh, self).__init__()
+ self.x_dim = x_dim
+ self.y_dim = y_dim
+ self.with_r = with_r
+ self.with_boundary = with_boundary
+
+ def forward(self, input_tensor, heatmap=None):
+ """
+ input_tensor: (batch, c, x_dim, y_dim)
+ """
+ batch_size_tensor = input_tensor.shape[0]
+
+ xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32).to(input_tensor)
+ xx_ones = xx_ones.unsqueeze(-1)
+
+ xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor)
+ xx_range = xx_range.unsqueeze(1)
+
+ xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
+ xx_channel = xx_channel.unsqueeze(-1)
+
+ yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32).to(input_tensor)
+ yy_ones = yy_ones.unsqueeze(1)
+
+ yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor)
+ yy_range = yy_range.unsqueeze(-1)
+
+ yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
+ yy_channel = yy_channel.unsqueeze(-1)
+
+ xx_channel = xx_channel.permute(0, 3, 2, 1)
+ yy_channel = yy_channel.permute(0, 3, 2, 1)
+
+ xx_channel = xx_channel / (self.x_dim - 1)
+ yy_channel = yy_channel / (self.y_dim - 1)
+
+ xx_channel = xx_channel * 2 - 1
+ yy_channel = yy_channel * 2 - 1
+
+ xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
+ yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
+
+ if self.with_boundary and type(heatmap) != type(None):
+ boundary_channel = torch.clamp(heatmap[:, -1:, :, :],
+ 0.0, 1.0)
+
+ zero_tensor = torch.zeros_like(xx_channel).to(xx_channel)
+ xx_boundary_channel = torch.where(boundary_channel>0.05,
+ xx_channel, zero_tensor)
+ yy_boundary_channel = torch.where(boundary_channel>0.05,
+ yy_channel, zero_tensor)
+ ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
+
+
+ if self.with_r:
+ rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
+ rr = rr / torch.max(rr)
+ ret = torch.cat([ret, rr], dim=1)
+
+ if self.with_boundary and type(heatmap) != type(None):
+ ret = torch.cat([ret, xx_boundary_channel,
+ yy_boundary_channel], dim=1)
+ return ret
+
+
+class CoordConvTh(nn.Module):
+ """CoordConv layer as in the paper."""
+ def __init__(self, x_dim, y_dim, with_r, with_boundary,
+ in_channels, out_channels, first_one=False, relu=False, bn=False, *args, **kwargs):
+ super(CoordConvTh, self).__init__()
+ self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r,
+ with_boundary=with_boundary)
+ in_channels += 2
+ if with_r:
+ in_channels += 1
+ if with_boundary and not first_one:
+ in_channels += 2
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, *args, **kwargs)
+ self.relu = nn.ReLU() if relu else None
+ self.bn = nn.BatchNorm2d(out_channels) if bn else None
+
+ self.with_boundary = with_boundary
+ self.first_one = first_one
+
+
+ def forward(self, input_tensor, heatmap=None):
+ assert (self.with_boundary and not self.first_one) == (heatmap is not None)
+ ret = self.addcoords(input_tensor, heatmap)
+ ret = self.conv(ret)
+ if self.bn is not None:
+ ret = self.bn(ret)
+ if self.relu is not None:
+ ret = self.relu(ret)
+
+ return ret
+
+
+'''
+An alternative implementation for PyTorch with auto-infering the x-y dimensions.
+'''
+class AddCoords(nn.Module):
+
+ def __init__(self, with_r=False):
+ super().__init__()
+ self.with_r = with_r
+
+ def forward(self, input_tensor):
+ """
+ Args:
+ input_tensor: shape(batch, channel, x_dim, y_dim)
+ """
+ batch_size, _, x_dim, y_dim = input_tensor.size()
+
+ xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1).to(input_tensor)
+ yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2).to(input_tensor)
+
+ xx_channel = xx_channel / (x_dim - 1)
+ yy_channel = yy_channel / (y_dim - 1)
+
+ xx_channel = xx_channel * 2 - 1
+ yy_channel = yy_channel * 2 - 1
+
+ xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
+ yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
+
+ ret = torch.cat([
+ input_tensor,
+ xx_channel.type_as(input_tensor),
+ yy_channel.type_as(input_tensor)], dim=1)
+
+ if self.with_r:
+ rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
+ ret = torch.cat([ret, rr], dim=1)
+
+ return ret
+
+
+class CoordConv(nn.Module):
+
+ def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
+ super().__init__()
+ self.addcoords = AddCoords(with_r=with_r)
+ in_channels += 2
+ if with_r:
+ in_channels += 1
+ self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
+
+ def forward(self, x):
+ ret = self.addcoords(x)
+ ret = self.conv(ret)
+ return ret
diff --git a/external/landmark_detection/lib/backbone/stackedHGNetV1.py b/external/landmark_detection/lib/backbone/stackedHGNetV1.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d02b5be3e0e491f75a627ac0d2fa2d179d9f06f
--- /dev/null
+++ b/external/landmark_detection/lib/backbone/stackedHGNetV1.py
@@ -0,0 +1,307 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .core.coord_conv import CoordConvTh
+from external.landmark_detection.lib.dataset import get_decoder
+
+
+
+class Activation(nn.Module):
+ def __init__(self, kind: str = 'relu', channel=None):
+ super().__init__()
+ self.kind = kind
+
+ if '+' in kind:
+ norm_str, act_str = kind.split('+')
+ else:
+ norm_str, act_str = 'none', kind
+
+ self.norm_fn = {
+ 'in': F.instance_norm,
+ 'bn': nn.BatchNorm2d(channel),
+ 'bn_noaffine': nn.BatchNorm2d(channel, affine=False, track_running_stats=True),
+ 'none': None
+ }[norm_str]
+
+ self.act_fn = {
+ 'relu': F.relu,
+ 'softplus': nn.Softplus(),
+ 'exp': torch.exp,
+ 'sigmoid': torch.sigmoid,
+ 'tanh': torch.tanh,
+ 'none': None
+ }[act_str]
+
+ self.channel = channel
+
+ def forward(self, x):
+ if self.norm_fn is not None:
+ x = self.norm_fn(x)
+ if self.act_fn is not None:
+ x = self.act_fn(x)
+ return x
+
+ def extra_repr(self):
+ return f'kind={self.kind}, channel={self.channel}'
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, groups=1):
+ super(ConvBlock, self).__init__()
+ self.inp_dim = inp_dim
+ self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size,
+ stride, padding=(kernel_size - 1) // 2, groups=groups, bias=True)
+ self.relu = None
+ self.bn = None
+ if relu:
+ self.relu = nn.ReLU()
+ if bn:
+ self.bn = nn.BatchNorm2d(out_dim)
+
+ def forward(self, x):
+ x = self.conv(x)
+ if self.bn is not None:
+ x = self.bn(x)
+ if self.relu is not None:
+ x = self.relu(x)
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, inp_dim, out_dim, mid_dim=None):
+ super(ResBlock, self).__init__()
+ if mid_dim is None:
+ mid_dim = out_dim // 2
+ self.relu = nn.ReLU()
+ self.bn1 = nn.BatchNorm2d(inp_dim)
+ self.conv1 = ConvBlock(inp_dim, mid_dim, 1, relu=False)
+ self.bn2 = nn.BatchNorm2d(mid_dim)
+ self.conv2 = ConvBlock(mid_dim, mid_dim, 3, relu=False)
+ self.bn3 = nn.BatchNorm2d(mid_dim)
+ self.conv3 = ConvBlock(mid_dim, out_dim, 1, relu=False)
+ self.skip_layer = ConvBlock(inp_dim, out_dim, 1, relu=False)
+ if inp_dim == out_dim:
+ self.need_skip = False
+ else:
+ self.need_skip = True
+
+ def forward(self, x):
+ if self.need_skip:
+ residual = self.skip_layer(x)
+ else:
+ residual = x
+ out = self.bn1(x)
+ out = self.relu(out)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ out = self.relu(out)
+ out = self.conv3(out)
+ out += residual
+ return out
+
+
+class Hourglass(nn.Module):
+ def __init__(self, n, f, increase=0, up_mode='nearest',
+ add_coord=False, first_one=False, x_dim=64, y_dim=64):
+ super(Hourglass, self).__init__()
+ nf = f + increase
+
+ Block = ResBlock
+
+ if add_coord:
+ self.coordconv = CoordConvTh(x_dim=x_dim, y_dim=y_dim,
+ with_r=True, with_boundary=True,
+ relu=False, bn=False,
+ in_channels=f, out_channels=f,
+ first_one=first_one,
+ kernel_size=1,
+ stride=1, padding=0)
+ else:
+ self.coordconv = None
+ self.up1 = Block(f, f)
+
+ # Lower branch
+ self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ self.low1 = Block(f, nf)
+ self.n = n
+ # Recursive hourglass
+ if self.n > 1:
+ self.low2 = Hourglass(n=n - 1, f=nf, increase=increase, up_mode=up_mode, add_coord=False)
+ else:
+ self.low2 = Block(nf, nf)
+ self.low3 = Block(nf, f)
+ self.up2 = nn.Upsample(scale_factor=2, mode=up_mode)
+
+ def forward(self, x, heatmap=None):
+ if self.coordconv is not None:
+ x = self.coordconv(x, heatmap)
+ up1 = self.up1(x)
+ pool1 = self.pool1(x)
+ low1 = self.low1(pool1)
+ low2 = self.low2(low1)
+ low3 = self.low3(low2)
+ up2 = self.up2(low3)
+ return up1 + up2
+
+
+class E2HTransform(nn.Module):
+ def __init__(self, edge_info, num_points, num_edges):
+ super().__init__()
+
+ e2h_matrix = np.zeros([num_points, num_edges])
+ for edge_id, isclosed_indices in enumerate(edge_info):
+ is_closed, indices = isclosed_indices
+ for point_id in indices:
+ e2h_matrix[point_id, edge_id] = 1
+ e2h_matrix = torch.from_numpy(e2h_matrix).float()
+
+ # pn x en x 1 x 1.
+ self.register_buffer('weight', e2h_matrix.view(
+ e2h_matrix.size(0), e2h_matrix.size(1), 1, 1))
+
+ # some keypoints are not coverred by any edges,
+ # in these cases, we must add a constant bias to their heatmap weights.
+ bias = ((e2h_matrix @ torch.ones(e2h_matrix.size(1)).to(
+ e2h_matrix)) < 0.5).to(e2h_matrix)
+ # pn x 1.
+ self.register_buffer('bias', bias)
+
+ def forward(self, edgemaps):
+ # input: batch_size x en x hw x hh.
+ # output: batch_size x pn x hw x hh.
+ return F.conv2d(edgemaps, weight=self.weight, bias=self.bias)
+
+
+class StackedHGNetV1(nn.Module):
+ def __init__(self, config, classes_num, edge_info,
+ nstack=4, nlevels=4, in_channel=256, increase=0,
+ add_coord=True, decoder_type='default'):
+ super(StackedHGNetV1, self).__init__()
+
+ self.cfg = config
+ self.coder_type = decoder_type
+ self.decoder = get_decoder(decoder_type=decoder_type)
+ self.nstack = nstack
+ self.add_coord = add_coord
+
+ self.num_heats = classes_num[0]
+
+ if self.add_coord:
+ convBlock = CoordConvTh(x_dim=self.cfg.width, y_dim=self.cfg.height,
+ with_r=True, with_boundary=False,
+ relu=True, bn=True,
+ in_channels=3, out_channels=64,
+ kernel_size=7,
+ stride=2, padding=3)
+ else:
+ convBlock = ConvBlock(3, 64, 7, 2, bn=True, relu=True)
+
+ pool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ Block = ResBlock
+
+ self.pre = nn.Sequential(
+ convBlock,
+ Block(64, 128),
+ pool,
+ Block(128, 128),
+ Block(128, in_channel)
+ )
+
+ self.hgs = nn.ModuleList(
+ [Hourglass(n=nlevels, f=in_channel, increase=increase, add_coord=self.add_coord, first_one=(_ == 0),
+ x_dim=int(self.cfg.width / self.nstack), y_dim=int(self.cfg.height / self.nstack))
+ for _ in range(nstack)])
+
+ self.features = nn.ModuleList([
+ nn.Sequential(
+ Block(in_channel, in_channel),
+ ConvBlock(in_channel, in_channel, 1, bn=True, relu=True)
+ ) for _ in range(nstack)])
+
+ self.out_heatmaps = nn.ModuleList(
+ [ConvBlock(in_channel, self.num_heats, 1, relu=False, bn=False)
+ for _ in range(nstack)])
+
+ if self.cfg.use_AAM:
+ self.num_edges = classes_num[1]
+ self.num_points = classes_num[2]
+
+ self.e2h_transform = E2HTransform(edge_info, self.num_points, self.num_edges)
+ self.out_edgemaps = nn.ModuleList(
+ [ConvBlock(in_channel, self.num_edges, 1, relu=False, bn=False)
+ for _ in range(nstack)])
+ self.out_pointmaps = nn.ModuleList(
+ [ConvBlock(in_channel, self.num_points, 1, relu=False, bn=False)
+ for _ in range(nstack)])
+ self.merge_edgemaps = nn.ModuleList(
+ [ConvBlock(self.num_edges, in_channel, 1, relu=False, bn=False)
+ for _ in range(nstack - 1)])
+ self.merge_pointmaps = nn.ModuleList(
+ [ConvBlock(self.num_points, in_channel, 1, relu=False, bn=False)
+ for _ in range(nstack - 1)])
+ self.edgemap_act = Activation("sigmoid", self.num_edges)
+ self.pointmap_act = Activation("sigmoid", self.num_points)
+
+ self.merge_features = nn.ModuleList(
+ [ConvBlock(in_channel, in_channel, 1, relu=False, bn=False)
+ for _ in range(nstack - 1)])
+ self.merge_heatmaps = nn.ModuleList(
+ [ConvBlock(self.num_heats, in_channel, 1, relu=False, bn=False)
+ for _ in range(nstack - 1)])
+
+ self.nstack = nstack
+
+ self.heatmap_act = Activation("in+relu", self.num_heats)
+
+ self.inference = False
+
+ def set_inference(self, inference):
+ self.inference = inference
+
+ def forward(self, x):
+ x = self.pre(x)
+
+ y, fusionmaps = [], []
+ heatmaps = None
+ for i in range(self.nstack):
+ hg = self.hgs[i](x, heatmap=heatmaps)
+ feature = self.features[i](hg)
+
+ heatmaps0 = self.out_heatmaps[i](feature)
+ heatmaps = self.heatmap_act(heatmaps0)
+
+ if self.cfg.use_AAM:
+ pointmaps0 = self.out_pointmaps[i](feature)
+ pointmaps = self.pointmap_act(pointmaps0)
+ edgemaps0 = self.out_edgemaps[i](feature)
+ edgemaps = self.edgemap_act(edgemaps0)
+ mask = self.e2h_transform(edgemaps) * pointmaps
+ fusion_heatmaps = mask * heatmaps
+ else:
+ fusion_heatmaps = heatmaps
+
+ landmarks = self.decoder.get_coords_from_heatmap(fusion_heatmaps)
+
+ if i < self.nstack - 1:
+ x = x + self.merge_features[i](feature) + \
+ self.merge_heatmaps[i](heatmaps)
+ if self.cfg.use_AAM:
+ x += self.merge_pointmaps[i](pointmaps)
+ x += self.merge_edgemaps[i](edgemaps)
+
+ y.append(landmarks)
+ if self.cfg.use_AAM:
+ y.append(pointmaps)
+ y.append(edgemaps)
+
+ fusionmaps.append(fusion_heatmaps)
+
+ return y, fusionmaps, landmarks
\ No newline at end of file
diff --git a/external/landmark_detection/lib/dataset/__init__.py b/external/landmark_detection/lib/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aded5c82fa3e59e29cc2a734eb4bc1dc57d0ed59
--- /dev/null
+++ b/external/landmark_detection/lib/dataset/__init__.py
@@ -0,0 +1,11 @@
+from .encoder import get_encoder
+from .decoder import get_decoder
+from .augmentation import Augmentation
+from .alignmentDataset import AlignmentDataset
+
+__all__ = [
+ "Augmentation",
+ "AlignmentDataset",
+ "get_encoder",
+ "get_decoder"
+]
diff --git a/external/landmark_detection/lib/dataset/alignmentDataset.py b/external/landmark_detection/lib/dataset/alignmentDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..70210ad37a418f92e806e48d2510f9cc41fd10f6
--- /dev/null
+++ b/external/landmark_detection/lib/dataset/alignmentDataset.py
@@ -0,0 +1,316 @@
+import os
+import sys
+import cv2
+import math
+import copy
+import hashlib
+import imageio
+import numpy as np
+import pandas as pd
+from scipy import interpolate
+from PIL import Image, ImageEnhance, ImageFile
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+sys.path.append("./")
+from external.landmark_detection.lib.dataset.augmentation import Augmentation
+from external.landmark_detection.lib.dataset.encoder import get_encoder
+
+
+class AlignmentDataset(Dataset):
+
+ def __init__(self, tsv_flie, image_dir="", transform=None,
+ width=256, height=256, channels=3,
+ means=(127.5, 127.5, 127.5), scale=1 / 127.5,
+ classes_num=None, crop_op=True, aug_prob=0.0, edge_info=None, flip_mapping=None, is_train=True,
+ encoder_type='default',
+ ):
+ super(AlignmentDataset, self).__init__()
+ self.use_AAM = True
+ self.encoder_type = encoder_type
+ self.encoder = get_encoder(height, width, encoder_type=encoder_type)
+ self.items = pd.read_csv(tsv_flie, sep="\t")
+ self.image_dir = image_dir
+ self.landmark_num = classes_num[0]
+ self.transform = transform
+
+ self.image_width = width
+ self.image_height = height
+ self.channels = channels
+ assert self.image_width == self.image_height
+
+ self.means = means
+ self.scale = scale
+
+ self.aug_prob = aug_prob
+ self.edge_info = edge_info
+ self.is_train = is_train
+ std_lmk_5pts = np.array([
+ 196.0, 226.0,
+ 316.0, 226.0,
+ 256.0, 286.0,
+ 220.0, 360.4,
+ 292.0, 360.4], np.float32) / 256.0 - 1.0
+ std_lmk_5pts = np.reshape(std_lmk_5pts, (5, 2)) # [-1 1]
+ target_face_scale = 1.0 if crop_op else 1.25
+
+ self.augmentation = Augmentation(
+ is_train=self.is_train,
+ aug_prob=self.aug_prob,
+ image_size=self.image_width,
+ crop_op=crop_op,
+ std_lmk_5pts=std_lmk_5pts,
+ target_face_scale=target_face_scale,
+ flip_rate=0.5,
+ flip_mapping=flip_mapping,
+ random_shift_sigma=0.05,
+ random_rot_sigma=math.pi / 180 * 18,
+ random_scale_sigma=0.1,
+ random_gray_rate=0.2,
+ random_occ_rate=0.4,
+ random_blur_rate=0.3,
+ random_gamma_rate=0.2,
+ random_nose_fusion_rate=0.2)
+
+ def _circle(self, img, pt, sigma=1.0, label_type='Gaussian'):
+ # Check that any part of the gaussian is in-bounds
+ tmp_size = sigma * 3
+ ul = [int(pt[0] - tmp_size), int(pt[1] - tmp_size)]
+ br = [int(pt[0] + tmp_size + 1), int(pt[1] + tmp_size + 1)]
+ if (ul[0] > img.shape[1] - 1 or ul[1] > img.shape[0] - 1 or
+ br[0] - 1 < 0 or br[1] - 1 < 0):
+ # If not, just return the image as is
+ return img
+
+ # Generate gaussian
+ size = 2 * tmp_size + 1
+ x = np.arange(0, size, 1, np.float32)
+ y = x[:, np.newaxis]
+ x0 = y0 = size // 2
+ # The gaussian is not normalized, we want the center value to equal 1
+ if label_type == 'Gaussian':
+ g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+ else:
+ g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
+
+ # Usable gaussian range
+ g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
+ g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
+ # Image range
+ img_x = max(0, ul[0]), min(br[0], img.shape[1])
+ img_y = max(0, ul[1]), min(br[1], img.shape[0])
+
+ img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = 255 * g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+ return img
+
+ def _polylines(self, img, lmks, is_closed, color=255, thickness=1, draw_mode=cv2.LINE_AA,
+ interpolate_mode=cv2.INTER_AREA, scale=4):
+ h, w = img.shape
+ img_scale = cv2.resize(img, (w * scale, h * scale), interpolation=interpolate_mode)
+ lmks_scale = (lmks * scale + 0.5).astype(np.int32)
+ cv2.polylines(img_scale, [lmks_scale], is_closed, color, thickness * scale, draw_mode)
+ img = cv2.resize(img_scale, (w, h), interpolation=interpolate_mode)
+ return img
+
+ def _generate_edgemap(self, points, scale=0.25, thickness=1):
+ h, w = self.image_height, self.image_width
+ edgemaps = []
+ for is_closed, indices in self.edge_info:
+ edgemap = np.zeros([h, w], dtype=np.float32)
+ # align_corners: False.
+ part = copy.deepcopy(points[np.array(indices)])
+
+ part = self._fit_curve(part, is_closed)
+ part[:, 0] = np.clip(part[:, 0], 0, w - 1)
+ part[:, 1] = np.clip(part[:, 1], 0, h - 1)
+ edgemap = self._polylines(edgemap, part, is_closed, 255, thickness)
+
+ edgemaps.append(edgemap)
+ edgemaps = np.stack(edgemaps, axis=0) / 255.0
+ edgemaps = torch.from_numpy(edgemaps).float().unsqueeze(0)
+ edgemaps = F.interpolate(edgemaps, size=(int(w * scale), int(h * scale)), mode='bilinear',
+ align_corners=False).squeeze()
+ return edgemaps
+
+ def _fit_curve(self, lmks, is_closed=False, density=5):
+ try:
+ x = lmks[:, 0].copy()
+ y = lmks[:, 1].copy()
+ if is_closed:
+ x = np.append(x, x[0])
+ y = np.append(y, y[0])
+ tck, u = interpolate.splprep([x, y], s=0, per=is_closed, k=3)
+ # bins = (x.shape[0] - 1) * density + 1
+ # lmk_x, lmk_y = interpolate.splev(np.linspace(0, 1, bins), f)
+ intervals = np.array([])
+ for i in range(len(u) - 1):
+ intervals = np.concatenate((intervals, np.linspace(u[i], u[i + 1], density, endpoint=False)))
+ if not is_closed:
+ intervals = np.concatenate((intervals, [u[-1]]))
+ lmk_x, lmk_y = interpolate.splev(intervals, tck, der=0)
+ # der_x, der_y = interpolate.splev(intervals, tck, der=1)
+ curve_lmks = np.stack([lmk_x, lmk_y], axis=-1)
+ # curve_ders = np.stack([der_x, der_y], axis=-1)
+ # origin_indices = np.arange(0, curve_lmks.shape[0], density)
+
+ return curve_lmks
+ except:
+ return lmks
+
+ def _image_id(self, image_path):
+ if not os.path.exists(image_path):
+ image_path = os.path.join(self.image_dir, image_path)
+ return hashlib.md5(open(image_path, "rb").read()).hexdigest()
+
+ def _load_image(self, image_path):
+ if not os.path.exists(image_path):
+ image_path = os.path.join(self.image_dir, image_path)
+
+ try:
+ # img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)#HWC, BGR, [0-255]
+ img = cv2.imread(image_path, cv2.IMREAD_COLOR) # HWC, BGR, [0-255]
+ assert img is not None and len(img.shape) == 3 and img.shape[2] == 3
+ except:
+ try:
+ img = imageio.imread(image_path) # HWC, RGB, [0-255]
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # HWC, BGR, [0-255]
+ assert img is not None and len(img.shape) == 3 and img.shape[2] == 3
+ except:
+ try:
+ gifImg = imageio.mimread(image_path) # BHWC, RGB, [0-255]
+ img = gifImg[0] # HWC, RGB, [0-255]
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # HWC, BGR, [0-255]
+ assert img is not None and len(img.shape) == 3 and img.shape[2] == 3
+ except:
+ img = None
+ return img
+
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+ cosv = math.cos(angle)
+ sinv = math.sin(angle)
+
+ fx, fy = from_center
+ tx, ty = to_center
+
+ acos = scale * cosv
+ asin = scale * sinv
+
+ a0 = acos
+ a1 = -asin
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+ b0 = asin
+ b1 = acos
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+ rot_scale_m = np.array([
+ [a0, a1, a2],
+ [b0, b1, b2],
+ [0.0, 0.0, 1.0]
+ ], np.float32)
+ return rot_scale_m
+
+ def _transformPoints2D(self, points, matrix):
+ """
+ points (nx2), matrix (3x3) -> points (nx2)
+ """
+ dtype = points.dtype
+
+ # nx3
+ points = np.concatenate([points, np.ones_like(points[:, [0]])], axis=1)
+ points = points @ np.transpose(matrix) # nx3
+ points = points[:, :2] / points[:, [2, 2]]
+ return points.astype(dtype)
+
+ def _transformPerspective(self, image, matrix, target_shape):
+ """
+ image, matrix3x3 -> transformed_image
+ """
+ return cv2.warpPerspective(
+ image, matrix,
+ dsize=(target_shape[1], target_shape[0]),
+ flags=cv2.INTER_LINEAR, borderValue=0)
+
+ def _norm_points(self, points, h, w, align_corners=False):
+ if align_corners:
+ # [0, SIZE-1] -> [-1, +1]
+ des_points = points / torch.tensor([w - 1, h - 1]).to(points).view(1, 2) * 2 - 1
+ else:
+ # [-0.5, SIZE-0.5] -> [-1, +1]
+ des_points = (points * 2 + 1) / torch.tensor([w, h]).to(points).view(1, 2) - 1
+ des_points = torch.clamp(des_points, -1, 1)
+ return des_points
+
+ def _denorm_points(self, points, h, w, align_corners=False):
+ if align_corners:
+ # [-1, +1] -> [0, SIZE-1]
+ des_points = (points + 1) / 2 * torch.tensor([w - 1, h - 1]).to(points).view(1, 1, 2)
+ else:
+ # [-1, +1] -> [-0.5, SIZE-0.5]
+ des_points = ((points + 1) * torch.tensor([w, h]).to(points).view(1, 1, 2) - 1) / 2
+ return des_points
+
+ def __len__(self):
+ return len(self.items)
+
+ def __getitem__(self, index):
+ sample = dict()
+
+ image_path = self.items.iloc[index, 0]
+ landmarks_5pts = self.items.iloc[index, 1]
+ landmarks_5pts = np.array(list(map(float, landmarks_5pts.split(","))), dtype=np.float32).reshape(5, 2)
+ landmarks_target = self.items.iloc[index, 2]
+ landmarks_target = np.array(list(map(float, landmarks_target.split(","))), dtype=np.float32).reshape(
+ self.landmark_num, 2)
+ scale = float(self.items.iloc[index, 3])
+ center_w, center_h = float(self.items.iloc[index, 4]), float(self.items.iloc[index, 5])
+ if len(self.items.iloc[index]) > 6:
+ tags = np.array(list(map(lambda x: int(float(x)), self.items.iloc[index, 6].split(","))))
+ else:
+ tags = np.array([])
+
+ # image & keypoints alignment
+ image_path = image_path.replace('\\', '/')
+ # wflw testset
+ image_path = image_path.replace(
+ '//msr-facestore/Workspace/MSRA_EP_Allergan/users/yanghuan/training_data/wflw/rawImages/', '')
+ # trainset
+ image_path = image_path.replace('./rawImages/', '')
+ image_path = os.path.join(self.image_dir, image_path)
+
+ # image path
+ sample["image_path"] = image_path
+
+ img = self._load_image(image_path) # HWC, BGR, [0, 255]
+ assert img is not None
+
+ # augmentation
+ # landmarks_target = [-0.5, edge-0.5]
+ img, landmarks_target, matrix = \
+ self.augmentation.process(img, landmarks_target, landmarks_5pts, scale, center_w, center_h)
+
+ landmarks = self._norm_points(torch.from_numpy(landmarks_target), self.image_height, self.image_width)
+
+ sample["label"] = [landmarks, ]
+
+ if self.use_AAM:
+ pointmap = self.encoder.generate_heatmap(landmarks_target)
+ edgemap = self._generate_edgemap(landmarks_target)
+ sample["label"] += [pointmap, edgemap]
+
+ sample['matrix'] = matrix
+
+ # image normalization
+ img = img.transpose(2, 0, 1).astype(np.float32) # CHW, BGR, [0, 255]
+ img[0, :, :] = (img[0, :, :] - self.means[0]) * self.scale
+ img[1, :, :] = (img[1, :, :] - self.means[1]) * self.scale
+ img[2, :, :] = (img[2, :, :] - self.means[2]) * self.scale
+ sample["data"] = torch.from_numpy(img) # CHW, BGR, [-1, 1]
+
+ sample["tags"] = tags
+
+ return sample
diff --git a/external/landmark_detection/lib/dataset/augmentation.py b/external/landmark_detection/lib/dataset/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2a2761d550d88ff51286a54eaae663d3b6e14a1
--- /dev/null
+++ b/external/landmark_detection/lib/dataset/augmentation.py
@@ -0,0 +1,355 @@
+import os
+import cv2
+import math
+import random
+import numpy as np
+from skimage import transform
+
+
+class Augmentation:
+ def __init__(self,
+ is_train=True,
+ aug_prob=1.0,
+ image_size=256,
+ crop_op=True,
+ std_lmk_5pts=None,
+ target_face_scale=1.0,
+ flip_rate=0.5,
+ flip_mapping=None,
+ random_shift_sigma=0.05,
+ random_rot_sigma=math.pi/180*18,
+ random_scale_sigma=0.1,
+ random_gray_rate=0.2,
+ random_occ_rate=0.4,
+ random_blur_rate=0.3,
+ random_gamma_rate=0.2,
+ random_nose_fusion_rate=0.2):
+ self.is_train = is_train
+ self.aug_prob = aug_prob
+ self.crop_op = crop_op
+ self._flip = Flip(flip_mapping, flip_rate)
+ if self.crop_op:
+ self._cropMatrix = GetCropMatrix(
+ image_size=image_size,
+ target_face_scale=target_face_scale,
+ align_corners=True)
+ else:
+ self._alignMatrix = GetAlignMatrix(
+ image_size=image_size,
+ target_face_scale=target_face_scale,
+ std_lmk_5pts=std_lmk_5pts)
+ self._randomGeometryMatrix = GetRandomGeometryMatrix(
+ target_shape=(image_size, image_size),
+ from_shape=(image_size, image_size),
+ shift_sigma=random_shift_sigma,
+ rot_sigma=random_rot_sigma,
+ scale_sigma=random_scale_sigma,
+ align_corners=True)
+ self._transform = Transform(image_size=image_size)
+ self._randomTexture = RandomTexture(
+ random_gray_rate=random_gray_rate,
+ random_occ_rate=random_occ_rate,
+ random_blur_rate=random_blur_rate,
+ random_gamma_rate=random_gamma_rate,
+ random_nose_fusion_rate=random_nose_fusion_rate)
+
+ def process(self, img, lmk, lmk_5pts=None, scale=1.0, center_w=0, center_h=0, is_train=True):
+ if self.is_train and random.random() < self.aug_prob:
+ img, lmk, lmk_5pts, center_w, center_h = self._flip.process(img, lmk, lmk_5pts, center_w, center_h)
+ matrix_geoaug = self._randomGeometryMatrix.process()
+ if self.crop_op:
+ matrix_pre = self._cropMatrix.process(scale, center_w, center_h)
+ else:
+ matrix_pre = self._alignMatrix.process(lmk_5pts)
+ matrix = matrix_geoaug @ matrix_pre
+ aug_img, aug_lmk = self._transform.process(img, lmk, matrix)
+ aug_img = self._randomTexture.process(aug_img)
+ else:
+ if self.crop_op:
+ matrix = self._cropMatrix.process(scale, center_w, center_h)
+ else:
+ matrix = self._alignMatrix.process(lmk_5pts)
+ aug_img, aug_lmk = self._transform.process(img, lmk, matrix)
+ return aug_img, aug_lmk, matrix
+
+
+class GetCropMatrix:
+ def __init__(self, image_size, target_face_scale, align_corners=False):
+ self.image_size = image_size
+ self.target_face_scale = target_face_scale
+ self.align_corners = align_corners
+
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+ cosv = math.cos(angle)
+ sinv = math.sin(angle)
+
+ fx, fy = from_center
+ tx, ty = to_center
+
+ acos = scale * cosv
+ asin = scale * sinv
+
+ a0 = acos
+ a1 = -asin
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+ b0 = asin
+ b1 = acos
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+ rot_scale_m = np.array([
+ [a0, a1, a2],
+ [b0, b1, b2],
+ [0.0, 0.0, 1.0]
+ ], np.float32)
+ return rot_scale_m
+
+ def process(self, scale, center_w, center_h):
+ if self.align_corners:
+ to_w, to_h = self.image_size-1, self.image_size-1
+ else:
+ to_w, to_h = self.image_size, self.image_size
+
+ rot_mu = 0
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
+ shift_xy_mu = (0, 0)
+ matrix = self._compose_rotate_and_scale(
+ rot_mu, scale_mu, shift_xy_mu,
+ from_center=[center_w, center_h],
+ to_center=[to_w/2.0, to_h/2.0])
+ return matrix
+
+
+class GetAlignMatrix:
+ def __init__(self, image_size, target_face_scale, std_lmk_5pts):
+ """
+ points in std_lmk_5pts range from -1 to 1.
+ """
+ self.std_lmk_5pts = (std_lmk_5pts * target_face_scale + 1) * \
+ np.array([image_size, image_size], np.float32) / 2.0
+
+ def process(self, lmk_5pts):
+ assert lmk_5pts.shape[-2:] == (5, 2)
+ tform = transform.SimilarityTransform()
+ tform.estimate(lmk_5pts, self.std_lmk_5pts)
+ return tform.params
+
+
+class GetRandomGeometryMatrix:
+ def __init__(self, target_shape, from_shape,
+ shift_sigma=0.1, rot_sigma=18*math.pi/180, scale_sigma=0.1,
+ shift_mu=0.0, rot_mu=0.0, scale_mu=1.0,
+ shift_normal=True, rot_normal=True, scale_normal=True,
+ align_corners=False):
+ self.target_shape = target_shape
+ self.from_shape = from_shape
+ self.shift_config = (shift_mu, shift_sigma, shift_normal)
+ self.rot_config = (rot_mu, rot_sigma, rot_normal)
+ self.scale_config = (scale_mu, scale_sigma, scale_normal)
+ self.align_corners = align_corners
+
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+ cosv = math.cos(angle)
+ sinv = math.sin(angle)
+
+ fx, fy = from_center
+ tx, ty = to_center
+
+ acos = scale * cosv
+ asin = scale * sinv
+
+ a0 = acos
+ a1 = -asin
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+ b0 = asin
+ b1 = acos
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+ rot_scale_m = np.array([
+ [a0, a1, a2],
+ [b0, b1, b2],
+ [0.0, 0.0, 1.0]
+ ], np.float32)
+ return rot_scale_m
+
+ def _random(self, mu_sigma_normal, size=None):
+ mu, sigma, is_normal = mu_sigma_normal
+ if is_normal:
+ return np.random.normal(mu, sigma, size=size)
+ else:
+ return np.random.uniform(low=mu-sigma, high=mu+sigma, size=size)
+
+ def process(self):
+ if self.align_corners:
+ from_w, from_h = self.from_shape[1]-1, self.from_shape[0]-1
+ to_w, to_h = self.target_shape[1]-1, self.target_shape[0]-1
+ else:
+ from_w, from_h = self.from_shape[1], self.from_shape[0]
+ to_w, to_h = self.target_shape[1], self.target_shape[0]
+
+ if self.shift_config[:2] != (0.0, 0.0) or \
+ self.rot_config[:2] != (0.0, 0.0) or \
+ self.scale_config[:2] != (1.0, 0.0):
+ shift_xy = self._random(self.shift_config, size=[2]) * \
+ min(to_h, to_w)
+ rot_angle = self._random(self.rot_config)
+ scale = self._random(self.scale_config)
+ matrix_geoaug = self._compose_rotate_and_scale(
+ rot_angle, scale, shift_xy,
+ from_center=[from_w/2.0, from_h/2.0],
+ to_center=[to_w/2.0, to_h/2.0])
+
+ return matrix_geoaug
+
+
+class Transform:
+ def __init__(self, image_size):
+ self.image_size = image_size
+
+ def _transformPoints2D(self, points, matrix):
+ """
+ points (nx2), matrix (3x3) -> points (nx2)
+ """
+ dtype = points.dtype
+
+ # nx3
+ points = np.concatenate([points, np.ones_like(points[:, [0]])], axis=1)
+ points = points @ np.transpose(matrix)
+ points = points[:, :2] / points[:, [2, 2]]
+ return points.astype(dtype)
+
+ def _transformPerspective(self, image, matrix):
+ """
+ image, matrix3x3 -> transformed_image
+ """
+ return cv2.warpPerspective(
+ image, matrix,
+ dsize=(self.image_size, self.image_size),
+ flags=cv2.INTER_LINEAR, borderValue=0)
+
+ def process(self, image, landmarks, matrix):
+ t_landmarks = self._transformPoints2D(landmarks, matrix)
+ t_image = self._transformPerspective(image, matrix)
+ return t_image, t_landmarks
+
+
+class RandomTexture:
+ def __init__(self, random_gray_rate=0, random_occ_rate=0, random_blur_rate=0, random_gamma_rate=0, random_nose_fusion_rate=0):
+ self.random_gray_rate = random_gray_rate
+ self.random_occ_rate = random_occ_rate
+ self.random_blur_rate = random_blur_rate
+ self.random_gamma_rate = random_gamma_rate
+ self.random_nose_fusion_rate = random_nose_fusion_rate
+ self.texture_augs = (
+ (self.add_occ, self.random_occ_rate),
+ (self.add_blur, self.random_blur_rate),
+ (self.add_gamma, self.random_gamma_rate),
+ (self.add_nose_fusion, self.random_nose_fusion_rate)
+ )
+
+ def add_gray(self, image):
+ assert image.ndim == 3 and image.shape[-1] == 3
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ image = np.tile(np.expand_dims(image, -1), [1, 1, 3])
+ return image
+
+ def add_occ(self, image):
+ h, w, c = image.shape
+ rh = 0.2 + 0.6 * random.random() # [0.2, 0.8]
+ rw = rh - 0.2 + 0.4 * random.random()
+ cx = int((h - 1) * random.random())
+ cy = int((w - 1) * random.random())
+ dh = int(h / 2 * rh)
+ dw = int(w / 2 * rw)
+ x0 = max(0, cx - dw // 2)
+ y0 = max(0, cy - dh // 2)
+ x1 = min(w - 1, cx + dw // 2)
+ y1 = min(h - 1, cy + dh // 2)
+ image[y0:y1+1, x0:x1+1] = 0
+ return image
+
+ def add_blur(self, image):
+ blur_kratio = 0.05 * random.random()
+ blur_ksize = int((image.shape[0] + image.shape[1]) / 2 * blur_kratio)
+ if blur_ksize > 1:
+ image = cv2.blur(image, (blur_ksize, blur_ksize))
+ return image
+
+ def add_gamma(self, image):
+ if random.random() < 0.5:
+ gamma = 0.25 + 0.75 * random.random()
+ else:
+ gamma = 1.0 + 3.0 * random.random()
+ image = (((image / 255.0) ** gamma) * 255).astype("uint8")
+ return image
+
+ def add_nose_fusion(self, image):
+ h, w, c = image.shape
+ nose = np.array(bytearray(os.urandom(h * w * c)), dtype=image.dtype).reshape(h, w, c)
+ alpha = 0.5 * random.random()
+ image = (1 - alpha) * image + alpha * nose
+ return image.astype(np.uint8)
+
+ def process(self, image):
+ image = image.copy()
+ if random.random() < self.random_occ_rate:
+ image = self.add_occ(image)
+ if random.random() < self.random_blur_rate:
+ image = self.add_blur(image)
+ if random.random() < self.random_gamma_rate:
+ image = self.add_gamma(image)
+ if random.random() < self.random_nose_fusion_rate:
+ image = self.add_nose_fusion(image)
+ """
+ orders = list(range(len(self.texture_augs)))
+ random.shuffle(orders)
+ for order in orders:
+ if random.random() < self.texture_augs[order][1]:
+ image = self.texture_augs[order][0](image)
+ """
+
+ if random.random() < self.random_gray_rate:
+ image = self.add_gray(image)
+
+ return image
+
+
+class Flip:
+ def __init__(self, flip_mapping, random_rate):
+ self.flip_mapping = flip_mapping
+ self.random_rate = random_rate
+
+ def process(self, image, landmarks, landmarks_5pts, center_w, center_h):
+ if random.random() >= self.random_rate or self.flip_mapping is None:
+ return image, landmarks, landmarks_5pts, center_w, center_h
+
+ # COFW
+ if landmarks.shape[0] == 29:
+ flip_offset = 0
+ # 300W, WFLW
+ elif landmarks.shape[0] in (68, 98):
+ flip_offset = -1
+ else:
+ flip_offset = -1
+
+ h, w, _ = image.shape
+ #image_flip = cv2.flip(image, 1)
+ image_flip = np.fliplr(image).copy()
+ landmarks_flip = landmarks.copy()
+ for i, j in self.flip_mapping:
+ landmarks_flip[i] = landmarks[j]
+ landmarks_flip[j] = landmarks[i]
+ landmarks_flip[:, 0] = w + flip_offset - landmarks_flip[:, 0]
+ if landmarks_5pts is not None:
+ flip_mapping = ([0, 1], [3, 4])
+ landmarks_5pts_flip = landmarks_5pts.copy()
+ for i, j in flip_mapping:
+ landmarks_5pts_flip[i] = landmarks_5pts[j]
+ landmarks_5pts_flip[j] = landmarks_5pts[i]
+ landmarks_5pts_flip[:, 0] = w + flip_offset - landmarks_5pts_flip[:, 0]
+ else:
+ landmarks_5pts_flip = None
+
+ center_w = w + flip_offset - center_w
+ return image_flip, landmarks_flip, landmarks_5pts_flip, center_w, center_h
diff --git a/external/landmark_detection/lib/dataset/decoder/__init__.py b/external/landmark_detection/lib/dataset/decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5d450d174bcf675bc13713f530b399a59447c5c
--- /dev/null
+++ b/external/landmark_detection/lib/dataset/decoder/__init__.py
@@ -0,0 +1,8 @@
+from .decoder_default import decoder_default
+
+def get_decoder(decoder_type='default'):
+ if decoder_type == 'default':
+ decoder = decoder_default()
+ else:
+ raise NotImplementedError
+ return decoder
\ No newline at end of file
diff --git a/external/landmark_detection/lib/dataset/decoder/decoder_default.py b/external/landmark_detection/lib/dataset/decoder/decoder_default.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b0b4edd1651a720656cb1d49b4ae786b739bc88
--- /dev/null
+++ b/external/landmark_detection/lib/dataset/decoder/decoder_default.py
@@ -0,0 +1,38 @@
+import torch
+
+
+class decoder_default:
+ def __init__(self, weight=1, use_weight_map=False):
+ self.weight = weight
+ self.use_weight_map = use_weight_map
+
+ def _make_grid(self, h, w):
+ yy, xx = torch.meshgrid(
+ torch.arange(h).float() / (h - 1) * 2 - 1,
+ torch.arange(w).float() / (w - 1) * 2 - 1)
+ return yy, xx
+
+ def get_coords_from_heatmap(self, heatmap):
+ """
+ inputs:
+ - heatmap: batch x npoints x h x w
+
+ outputs:
+ - coords: batch x npoints x 2 (x,y), [-1, +1]
+ - radius_sq: batch x npoints
+ """
+ batch, npoints, h, w = heatmap.shape
+ if self.use_weight_map:
+ heatmap = heatmap * self.weight
+
+ yy, xx = self._make_grid(h, w)
+ yy = yy.view(1, 1, h, w).to(heatmap)
+ xx = xx.view(1, 1, h, w).to(heatmap)
+
+ heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
+
+ yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum # batch x npoints
+ xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum # batch x npoints
+ coords = torch.stack([xx_coord, yy_coord], dim=-1)
+
+ return coords
diff --git a/external/landmark_detection/lib/dataset/encoder/__init__.py b/external/landmark_detection/lib/dataset/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..42d0b6f9f465e2c825dc2096ee6cd828bede52f8
--- /dev/null
+++ b/external/landmark_detection/lib/dataset/encoder/__init__.py
@@ -0,0 +1,8 @@
+from .encoder_default import encoder_default
+
+def get_encoder(image_height, image_width, scale=0.25, sigma=1.5, encoder_type='default'):
+ if encoder_type == 'default':
+ encoder = encoder_default(image_height, image_width, scale, sigma)
+ else:
+ raise NotImplementedError
+ return encoder
diff --git a/external/landmark_detection/lib/dataset/encoder/encoder_default.py b/external/landmark_detection/lib/dataset/encoder/encoder_default.py
new file mode 100644
index 0000000000000000000000000000000000000000..92c22b131c192371dec812e88a9ecbb37cd72020
--- /dev/null
+++ b/external/landmark_detection/lib/dataset/encoder/encoder_default.py
@@ -0,0 +1,63 @@
+import copy
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+
+
+class encoder_default:
+ def __init__(self, image_height, image_width, scale=0.25, sigma=1.5):
+ self.image_height = image_height
+ self.image_width = image_width
+ self.scale = scale
+ self.sigma = sigma
+
+ def generate_heatmap(self, points):
+ # points = (num_pts, 2)
+ h, w = self.image_height, self.image_width
+ pointmaps = []
+ for i in range(len(points)):
+ pointmap = np.zeros([h, w], dtype=np.float32)
+ # align_corners: False.
+ point = copy.deepcopy(points[i])
+ point[0] = max(0, min(w - 1, point[0]))
+ point[1] = max(0, min(h - 1, point[1]))
+ pointmap = self._circle(pointmap, point, sigma=self.sigma)
+
+ pointmaps.append(pointmap)
+ pointmaps = np.stack(pointmaps, axis=0) / 255.0
+ pointmaps = torch.from_numpy(pointmaps).float().unsqueeze(0)
+ pointmaps = F.interpolate(pointmaps, size=(int(w * self.scale), int(h * self.scale)), mode='bilinear',
+ align_corners=False).squeeze()
+ return pointmaps
+
+ def _circle(self, img, pt, sigma=1.0, label_type='Gaussian'):
+ # Check that any part of the gaussian is in-bounds
+ tmp_size = sigma * 3
+ ul = [int(pt[0] - tmp_size), int(pt[1] - tmp_size)]
+ br = [int(pt[0] + tmp_size + 1), int(pt[1] + tmp_size + 1)]
+ if (ul[0] > img.shape[1] - 1 or ul[1] > img.shape[0] - 1 or
+ br[0] - 1 < 0 or br[1] - 1 < 0):
+ # If not, just return the image as is
+ return img
+
+ # Generate gaussian
+ size = 2 * tmp_size + 1
+ x = np.arange(0, size, 1, np.float32)
+ y = x[:, np.newaxis]
+ x0 = y0 = size // 2
+ # The gaussian is not normalized, we want the center value to equal 1
+ if label_type == 'Gaussian':
+ g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+ else:
+ g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
+
+ # Usable gaussian range
+ g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
+ g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
+ # Image range
+ img_x = max(0, ul[0]), min(br[0], img.shape[1])
+ img_y = max(0, ul[1]), min(br[1], img.shape[0])
+
+ img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = 255 * g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+ return img
diff --git a/external/landmark_detection/lib/loss/__init__.py b/external/landmark_detection/lib/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b40209fb768ce78c3ff90a1dd7472710d5ad2ee
--- /dev/null
+++ b/external/landmark_detection/lib/loss/__init__.py
@@ -0,0 +1,14 @@
+from .awingLoss import AWingLoss
+from .smoothL1Loss import SmoothL1Loss
+from .wingLoss import WingLoss
+from .starLoss import STARLoss
+from .starLoss_v2 import STARLoss_v2
+
+__all__ = [
+ "AWingLoss",
+ "SmoothL1Loss",
+ "WingLoss",
+ "STARLoss",
+
+ "STARLoss_v2",
+]
diff --git a/external/landmark_detection/lib/loss/awingLoss.py b/external/landmark_detection/lib/loss/awingLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..531a6cd1616c927e2f48dece13ab6e37f384421a
--- /dev/null
+++ b/external/landmark_detection/lib/loss/awingLoss.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AWingLoss(nn.Module):
+ def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1, use_weight_map=True):
+ super(AWingLoss, self).__init__()
+ self.omega = omega
+ self.theta = theta
+ self.epsilon = epsilon
+ self.alpha = alpha
+ self.use_weight_map = use_weight_map
+
+ def __repr__(self):
+ return "AWingLoss()"
+
+ def generate_weight_map(self, heatmap, k_size=3, w=10):
+ dilate = F.max_pool2d(heatmap, kernel_size=k_size, stride=1, padding=1)
+ weight_map = torch.where(dilate < 0.2, torch.zeros_like(heatmap), torch.ones_like(heatmap))
+ return w * weight_map + 1
+
+ def forward(self, output, groundtruth):
+ """
+ input: b x n x h x w
+ output: b x n x h x w => 1
+ """
+ delta = (output - groundtruth).abs()
+ A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - groundtruth))) * (self.alpha - groundtruth) * \
+ (torch.pow(self.theta / self.epsilon, self.alpha - groundtruth - 1)) * (1 / self.epsilon)
+ C = self.theta * A - self.omega * \
+ torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - groundtruth))
+ loss = torch.where(delta < self.theta,
+ self.omega * torch.log(1 + torch.pow(delta / self.epsilon, self.alpha - groundtruth)),
+ (A * delta - C))
+ if self.use_weight_map:
+ weight = self.generate_weight_map(groundtruth)
+ loss = loss * weight
+ return loss.mean()
diff --git a/external/landmark_detection/lib/loss/smoothL1Loss.py b/external/landmark_detection/lib/loss/smoothL1Loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bac18f6d9128e294b65ba04cd25357971331c114
--- /dev/null
+++ b/external/landmark_detection/lib/loss/smoothL1Loss.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+
+
+class SmoothL1Loss(nn.Module):
+ def __init__(self, scale=0.01):
+ super(SmoothL1Loss, self).__init__()
+ self.scale = scale
+ self.EPSILON = 1e-10
+
+ def __repr__(self):
+ return "SmoothL1Loss()"
+
+ def forward(self, output: torch.Tensor, groundtruth: torch.Tensor, reduction='mean'):
+ """
+ input: b x n x 2
+ output: b x n x 1 => 1
+ """
+ if output.dim() == 4:
+ shape = output.shape
+ groundtruth = groundtruth.reshape(shape[0], shape[1], 1, shape[3])
+
+ delta_2 = (output - groundtruth).pow(2).sum(dim=-1, keepdim=False)
+ delta = delta_2.clamp(min=1e-6).sqrt()
+ # delta = torch.sqrt(delta_2 + self.EPSILON)
+ loss = torch.where( \
+ delta_2 < self.scale * self.scale, \
+ 0.5 / self.scale * delta_2, \
+ delta - 0.5 * self.scale)
+
+ if reduction == 'mean':
+ loss = loss.mean()
+ elif reduction == 'sum':
+ loss = loss.sum()
+
+ return loss
diff --git a/external/landmark_detection/lib/loss/starLoss.py b/external/landmark_detection/lib/loss/starLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbb7eb617dea0ec12c5624f26a4389ae554ef7e2
--- /dev/null
+++ b/external/landmark_detection/lib/loss/starLoss.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+from .smoothL1Loss import SmoothL1Loss
+from .wingLoss import WingLoss
+
+
+def get_channel_sum(input):
+ temp = torch.sum(input, dim=3)
+ output = torch.sum(temp, dim=2)
+ return output
+
+
+def expand_two_dimensions_at_end(input, dim1, dim2):
+ input = input.unsqueeze(-1).unsqueeze(-1)
+ input = input.expand(-1, -1, dim1, dim2)
+ return input
+
+
+class STARLoss(nn.Module):
+ def __init__(self, w=1, dist='smoothl1', num_dim_image=2, EPSILON=1e-5):
+ super(STARLoss, self).__init__()
+ self.w = w
+ self.num_dim_image = num_dim_image
+ self.EPSILON = EPSILON
+ self.dist = dist
+ if self.dist == 'smoothl1':
+ self.dist_func = SmoothL1Loss()
+ elif self.dist == 'l1':
+ self.dist_func = F.l1_loss
+ elif self.dist == 'l2':
+ self.dist_func = F.mse_loss
+ elif self.dist == 'wing':
+ self.dist_func = WingLoss()
+ else:
+ raise NotImplementedError
+
+ def __repr__(self):
+ return "STARLoss()"
+
+ def _make_grid(self, h, w):
+ yy, xx = torch.meshgrid(
+ torch.arange(h).float() / (h - 1) * 2 - 1,
+ torch.arange(w).float() / (w - 1) * 2 - 1)
+ return yy, xx
+
+ def weighted_mean(self, heatmap):
+ batch, npoints, h, w = heatmap.shape
+
+ yy, xx = self._make_grid(h, w)
+ yy = yy.view(1, 1, h, w).to(heatmap)
+ xx = xx.view(1, 1, h, w).to(heatmap)
+
+ yy_coord = (yy * heatmap).sum([2, 3]) # batch x npoints
+ xx_coord = (xx * heatmap).sum([2, 3]) # batch x npoints
+ coords = torch.stack([xx_coord, yy_coord], dim=-1)
+ return coords
+
+ def unbiased_weighted_covariance(self, htp, means, num_dim_image=2, EPSILON=1e-5):
+ batch_size, num_points, height, width = htp.shape
+
+ yv, xv = self._make_grid(height, width)
+ xv = Variable(xv)
+ yv = Variable(yv)
+
+ if htp.is_cuda:
+ xv = xv.cuda()
+ yv = yv.cuda()
+
+ xmean = means[:, :, 0]
+ xv_minus_mean = xv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(xmean, height,
+ width) # [batch_size, 68, 64, 64]
+ ymean = means[:, :, 1]
+ yv_minus_mean = yv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(ymean, height,
+ width) # [batch_size, 68, 64, 64]
+ wt_xv_minus_mean = xv_minus_mean
+ wt_yv_minus_mean = yv_minus_mean
+
+ wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096]
+ wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096]
+ wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096]
+ wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096]
+ vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1) # [batch_size*68, 2, 4096]
+
+ htp_vec = htp.view(batch_size * num_points, 1, height * width)
+ htp_vec = htp_vec.expand(-1, 2, -1)
+
+ covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2)) # [batch_size*68, 2, 2]
+ covariance = covariance.view(batch_size, num_points, num_dim_image, num_dim_image) # [batch_size, 68, 2, 2]
+
+ V_1 = htp.sum([2, 3]) + EPSILON # [batch_size, 68]
+ V_2 = torch.pow(htp, 2).sum([2, 3]) + EPSILON # [batch_size, 68]
+
+ denominator = V_1 - (V_2 / V_1)
+ covariance = covariance / expand_two_dimensions_at_end(denominator, num_dim_image, num_dim_image)
+
+ return covariance
+
+ def ambiguity_guided_decompose(self, pts, eigenvalues, eigenvectors):
+ batch_size, npoints = pts.shape[:2]
+ rotate = torch.matmul(pts.view(batch_size, npoints, 1, 2), eigenvectors.transpose(-1, -2))
+ scale = rotate.view(batch_size, npoints, 2) / torch.sqrt(eigenvalues + self.EPSILON)
+ return scale
+
+ def eigenvalue_restriction(self, evalues, batch, npoints):
+ eigen_loss = torch.abs(evalues.view(batch * npoints, 2)).sum(-1)
+ return eigen_loss.mean()
+
+ def forward(self, heatmap, groundtruth):
+ """
+ heatmap: b x n x 64 x 64
+ groundtruth: b x n x 2
+ output: b x n x 1 => 1
+ """
+ # normalize
+ bs, npoints, h, w = heatmap.shape
+ heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
+ heatmap = heatmap / heatmap_sum.view(bs, npoints, 1, 1)
+
+ means = self.weighted_mean(heatmap) # [bs, 68, 2]
+ covars = self.unbiased_weighted_covariance(heatmap, means) # covars [bs, 68, 2, 2]
+
+ # TODO: GPU-based eigen-decomposition
+ # https://github.com/pytorch/pytorch/issues/60537
+ _covars = covars.view(bs * npoints, 2, 2).cpu()
+ evalues, evectors = _covars.symeig(eigenvectors=True) # evalues [bs * 68, 2], evectors [bs * 68, 2, 2]
+ evalues = evalues.view(bs, npoints, 2).to(heatmap)
+ evectors = evectors.view(bs, npoints, 2, 2).to(heatmap)
+
+ # STAR Loss
+ # Ambiguity-guided Decomposition
+ error = self.ambiguity_guided_decompose(groundtruth - means, evalues, evectors)
+ loss_trans = self.dist_func(torch.zeros_like(error).to(error), error)
+ # Eigenvalue Restriction
+ loss_eigen = self.eigenvalue_restriction(evalues, bs, npoints)
+ star_loss = loss_trans + self.w * loss_eigen
+
+ return star_loss
diff --git a/external/landmark_detection/lib/loss/starLoss_v2.py b/external/landmark_detection/lib/loss/starLoss_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..871a70abd491afceed18214264e933b6f5f46e27
--- /dev/null
+++ b/external/landmark_detection/lib/loss/starLoss_v2.py
@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+from .smoothL1Loss import SmoothL1Loss
+from .wingLoss import WingLoss
+
+
+def get_channel_sum(input):
+ temp = torch.sum(input, dim=3)
+ output = torch.sum(temp, dim=2)
+ return output
+
+
+def expand_two_dimensions_at_end(input, dim1, dim2):
+ input = input.unsqueeze(-1).unsqueeze(-1)
+ input = input.expand(-1, -1, dim1, dim2)
+ return input
+
+
+class STARLoss_v2(nn.Module):
+ def __init__(self, w=1, dist='smoothl1', num_dim_image=2, EPSILON=1e-5):
+ super(STARLoss_v2, self).__init__()
+ self.w = w
+ self.num_dim_image = num_dim_image
+ self.EPSILON = EPSILON
+ self.dist = dist
+ if self.dist == 'smoothl1':
+ self.dist_func = SmoothL1Loss()
+ elif self.dist == 'l1':
+ self.dist_func = F.l1_loss
+ elif self.dist == 'l2':
+ self.dist_func = F.mse_loss
+ elif self.dist == 'wing':
+ self.dist_func = WingLoss()
+ else:
+ raise NotImplementedError
+
+ def __repr__(self):
+ return "STARLoss()"
+
+ def _make_grid(self, h, w):
+ yy, xx = torch.meshgrid(
+ torch.arange(h).float() / (h - 1) * 2 - 1,
+ torch.arange(w).float() / (w - 1) * 2 - 1)
+ return yy, xx
+
+ def weighted_mean(self, heatmap):
+ batch, npoints, h, w = heatmap.shape
+
+ yy, xx = self._make_grid(h, w)
+ yy = yy.view(1, 1, h, w).to(heatmap)
+ xx = xx.view(1, 1, h, w).to(heatmap)
+
+ yy_coord = (yy * heatmap).sum([2, 3]) # batch x npoints
+ xx_coord = (xx * heatmap).sum([2, 3]) # batch x npoints
+ coords = torch.stack([xx_coord, yy_coord], dim=-1)
+ return coords
+
+ def unbiased_weighted_covariance(self, htp, means, num_dim_image=2, EPSILON=1e-5):
+ batch_size, num_points, height, width = htp.shape
+
+ yv, xv = self._make_grid(height, width)
+ xv = Variable(xv)
+ yv = Variable(yv)
+
+ if htp.is_cuda:
+ xv = xv.cuda()
+ yv = yv.cuda()
+
+ xmean = means[:, :, 0]
+ xv_minus_mean = xv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(xmean, height,
+ width) # [batch_size, 68, 64, 64]
+ ymean = means[:, :, 1]
+ yv_minus_mean = yv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(ymean, height,
+ width) # [batch_size, 68, 64, 64]
+ wt_xv_minus_mean = xv_minus_mean
+ wt_yv_minus_mean = yv_minus_mean
+
+ wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096]
+ wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096]
+ wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096]
+ wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096]
+ vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1) # [batch_size*68, 2, 4096]
+
+ htp_vec = htp.view(batch_size * num_points, 1, height * width)
+ htp_vec = htp_vec.expand(-1, 2, -1)
+
+ covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2)) # [batch_size*68, 2, 2]
+ covariance = covariance.view(batch_size, num_points, num_dim_image, num_dim_image) # [batch_size, 68, 2, 2]
+
+ V_1 = htp.sum([2, 3]) + EPSILON # [batch_size, 68]
+ V_2 = torch.pow(htp, 2).sum([2, 3]) + EPSILON # [batch_size, 68]
+
+ denominator = V_1 - (V_2 / V_1)
+ covariance = covariance / expand_two_dimensions_at_end(denominator, num_dim_image, num_dim_image)
+
+ return covariance
+
+ def ambiguity_guided_decompose(self, error, evalues, evectors):
+ bs, npoints = error.shape[:2]
+ normal_vector = evectors[:, :, 0]
+ tangent_vector = evectors[:, :, 1]
+ normal_error = torch.matmul(normal_vector.unsqueeze(-2), error.unsqueeze(-1))
+ tangent_error = torch.matmul(tangent_vector.unsqueeze(-2), error.unsqueeze(-1))
+ normal_error = normal_error.squeeze(dim=-1)
+ tangent_error = tangent_error.squeeze(dim=-1)
+ normal_dist = self.dist_func(normal_error, torch.zeros_like(normal_error).to(normal_error), reduction='none')
+ tangent_dist = self.dist_func(tangent_error, torch.zeros_like(tangent_error).to(tangent_error), reduction='none')
+ normal_dist = normal_dist.reshape(bs, npoints, 1)
+ tangent_dist = tangent_dist.reshape(bs, npoints, 1)
+ dist = torch.cat((normal_dist, tangent_dist), dim=-1)
+ scale_dist = dist / torch.sqrt(evalues + self.EPSILON)
+ scale_dist = scale_dist.sum(-1)
+ return scale_dist
+
+ def eigenvalue_restriction(self, evalues, batch, npoints):
+ eigen_loss = torch.abs(evalues.view(batch, npoints, 2)).sum(-1)
+ return eigen_loss
+
+ def forward(self, heatmap, groundtruth):
+ """
+ heatmap: b x n x 64 x 64
+ groundtruth: b x n x 2
+ output: b x n x 1 => 1
+ """
+ # normalize
+ bs, npoints, h, w = heatmap.shape
+ heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
+ heatmap = heatmap / heatmap_sum.view(bs, npoints, 1, 1)
+
+ means = self.weighted_mean(heatmap) # [bs, 68, 2]
+ covars = self.unbiased_weighted_covariance(heatmap, means) # covars [bs, 68, 2, 2]
+
+ # TODO: GPU-based eigen-decomposition
+ # https://github.com/pytorch/pytorch/issues/60537
+ _covars = covars.view(bs * npoints, 2, 2).cpu()
+ evalues, evectors = _covars.symeig(eigenvectors=True) # evalues [bs * 68, 2], evectors [bs * 68, 2, 2]
+ evalues = evalues.view(bs, npoints, 2).to(heatmap)
+ evectors = evectors.view(bs, npoints, 2, 2).to(heatmap)
+
+ # STAR Loss
+ # Ambiguity-guided Decomposition
+ loss_trans = self.ambiguity_guided_decompose(groundtruth - means, evalues, evectors)
+ # Eigenvalue Restriction
+ loss_eigen = self.eigenvalue_restriction(evalues, bs, npoints)
+ star_loss = loss_trans + self.w * loss_eigen
+
+ return star_loss.mean()
diff --git a/external/landmark_detection/lib/loss/wingLoss.py b/external/landmark_detection/lib/loss/wingLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab3868030c6b91877f04bc7e296b0059519e4ee
--- /dev/null
+++ b/external/landmark_detection/lib/loss/wingLoss.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+
+import math
+import torch
+from torch import nn
+
+
+# torch.log and math.log is e based
+class WingLoss(nn.Module):
+ def __init__(self, omega=0.01, epsilon=2):
+ super(WingLoss, self).__init__()
+ self.omega = omega
+ self.epsilon = epsilon
+
+ def forward(self, pred, target):
+ y = target
+ y_hat = pred
+ delta_2 = (y - y_hat).pow(2).sum(dim=-1, keepdim=False)
+ # delta = delta_2.sqrt()
+ delta = delta_2.clamp(min=1e-6).sqrt()
+ C = self.omega - self.omega * math.log(1 + self.omega / self.epsilon)
+ loss = torch.where(
+ delta < self.omega,
+ self.omega * torch.log(1 + delta / self.epsilon),
+ delta - C
+ )
+ return loss.mean()
diff --git a/external/landmark_detection/lib/metric/__init__.py b/external/landmark_detection/lib/metric/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a872b3320da6ed683027882a18290bdfa6e36026
--- /dev/null
+++ b/external/landmark_detection/lib/metric/__init__.py
@@ -0,0 +1,11 @@
+from .nme import NME
+from .accuracy import Accuracy
+from .fr_and_auc import FR_AUC
+from .params import count_parameters_in_MB
+
+__all__ = [
+ "NME",
+ "Accuracy",
+ "FR_AUC",
+ 'count_parameters_in_MB',
+]
diff --git a/external/landmark_detection/lib/metric/accuracy.py b/external/landmark_detection/lib/metric/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..082ddac7826b6ff69f147629cdaa373ad3a7b699
--- /dev/null
+++ b/external/landmark_detection/lib/metric/accuracy.py
@@ -0,0 +1,21 @@
+import torch
+import torch.nn.functional as F
+
+class Accuracy:
+ def __init__(self):
+ pass
+
+ def __repr__(self):
+ return "Accuracy()"
+
+ def test(self, label_pd, label_gt, ignore_label=-1):
+ correct_cnt = 0
+ total_cnt = 0
+ with torch.no_grad():
+ label_pd = F.softmax(label_pd, dim=1)
+ label_pd = torch.max(label_pd, 1)[1]
+ label_gt = label_gt.long()
+ c = (label_pd == label_gt)
+ correct_cnt = torch.sum(c).item()
+ total_cnt = c.size(0) - torch.sum(label_gt==ignore_label).item()
+ return correct_cnt, total_cnt
diff --git a/external/landmark_detection/lib/metric/fr_and_auc.py b/external/landmark_detection/lib/metric/fr_and_auc.py
new file mode 100644
index 0000000000000000000000000000000000000000..995e4de5a2ef449773b0580ffaee7549deefee1c
--- /dev/null
+++ b/external/landmark_detection/lib/metric/fr_and_auc.py
@@ -0,0 +1,25 @@
+import numpy as np
+from scipy.integrate import simps
+
+
+class FR_AUC:
+ def __init__(self, data_definition):
+ self.data_definition = data_definition
+ if data_definition == '300W':
+ self.thresh = 0.05
+ else:
+ self.thresh = 0.1
+
+ def __repr__(self):
+ return "FR_AUC()"
+
+ def test(self, nmes, thres=None, step=0.0001):
+ if thres is None:
+ thres = self.thresh
+
+ num_data = len(nmes)
+ xs = np.arange(0, thres + step, step)
+ ys = np.array([np.count_nonzero(nmes <= x) for x in xs]) / float(num_data)
+ fr = 1.0 - ys[-1]
+ auc = simps(ys, x=xs) / thres
+ return [round(fr, 4), round(auc, 6)]
diff --git a/external/landmark_detection/lib/metric/nme.py b/external/landmark_detection/lib/metric/nme.py
new file mode 100644
index 0000000000000000000000000000000000000000..942fba3fecc12e2f3e21c477280d84f219e8f856
--- /dev/null
+++ b/external/landmark_detection/lib/metric/nme.py
@@ -0,0 +1,39 @@
+import torch
+import numpy as np
+
+class NME:
+ def __init__(self, nme_left_index, nme_right_index):
+ self.nme_left_index = nme_left_index
+ self.nme_right_index = nme_right_index
+
+ def __repr__(self):
+ return "NME()"
+
+ def get_norm_distance(self, landmarks):
+ assert isinstance(self.nme_right_index, list), 'the nme_right_index is not list.'
+ assert isinstance(self.nme_left_index, list), 'the nme_left, index is not list.'
+ right_pupil = landmarks[self.nme_right_index, :].mean(0)
+ left_pupil = landmarks[self.nme_left_index, :].mean(0)
+ norm_distance = np.linalg.norm(right_pupil - left_pupil)
+ return norm_distance
+
+ def test(self, label_pd, label_gt):
+ nme_list = []
+ label_pd = label_pd.data.cpu().numpy()
+ label_gt = label_gt.data.cpu().numpy()
+
+ for i in range(label_gt.shape[0]):
+ landmarks_gt = label_gt[i]
+ landmarks_pv = label_pd[i]
+ if isinstance(self.nme_right_index, list):
+ norm_distance = self.get_norm_distance(landmarks_gt)
+ elif isinstance(self.nme_right_index, int):
+ norm_distance = np.linalg.norm(landmarks_gt[self.nme_left_index] - landmarks_gt[self.nme_right_index])
+ else:
+ raise NotImplementedError
+ landmarks_delta = landmarks_pv - landmarks_gt
+ nme = (np.linalg.norm(landmarks_delta, axis=1) / norm_distance).mean()
+ nme_list.append(nme)
+ # sum_nme += nme
+ # total_cnt += 1
+ return nme_list
diff --git a/external/landmark_detection/lib/metric/params.py b/external/landmark_detection/lib/metric/params.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d78cfd0ba1a2f26410b18884073e37c8d17a38b
--- /dev/null
+++ b/external/landmark_detection/lib/metric/params.py
@@ -0,0 +1,7 @@
+import torch.nn as nn
+
+def count_parameters_in_MB(model):
+ if isinstance(model, nn.Module):
+ return sum(v.numel() for v in model.parameters()) / 1e6
+ else:
+ return sum(v.numel() for v in model) / 1e6
\ No newline at end of file
diff --git a/external/landmark_detection/lib/utility.py b/external/landmark_detection/lib/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..f52e736ea20e2a25d094daf2be2a7529c27342b9
--- /dev/null
+++ b/external/landmark_detection/lib/utility.py
@@ -0,0 +1,362 @@
+import json
+import os.path as osp
+import time
+import torch
+import numpy as np
+from tqdm import tqdm
+
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader, DistributedSampler
+import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
+import torch.nn.functional as F
+
+# private package
+from external.landmark_detection.conf import *
+from external.landmark_detection.lib.dataset import AlignmentDataset
+from external.landmark_detection.lib.backbone import StackedHGNetV1
+from external.landmark_detection.lib.loss import *
+from external.landmark_detection.lib.metric import NME, FR_AUC
+from external.landmark_detection.lib.utils import convert_secs2time
+from external.landmark_detection.lib.utils import AverageMeter
+
+
+def get_config(args):
+ config = None
+ config_name = args.config_name
+ config = Alignment(args)
+
+
+ return config
+
+
+def get_dataset(config, tsv_file, image_dir, loader_type, is_train):
+ dataset = None
+ if loader_type == "alignment":
+ dataset = AlignmentDataset(
+ tsv_file,
+ image_dir,
+ transforms.Compose([transforms.ToTensor()]),
+ config.width,
+ config.height,
+ config.channels,
+ config.means,
+ config.scale,
+ config.classes_num,
+ config.crop_op,
+ config.aug_prob,
+ config.edge_info,
+ config.flip_mapping,
+ is_train,
+ encoder_type=config.encoder_type
+ )
+ else:
+ assert False
+ return dataset
+
+
+def get_dataloader(config, data_type, world_rank=0, world_size=1):
+ loader = None
+ if data_type == "train":
+ dataset = get_dataset(
+ config,
+ config.train_tsv_file,
+ config.train_pic_dir,
+ config.loader_type,
+ is_train=True)
+ if world_size > 1:
+ sampler = DistributedSampler(dataset, rank=world_rank, num_replicas=world_size, shuffle=True)
+ loader = DataLoader(dataset, sampler=sampler, batch_size=config.batch_size // world_size,
+ num_workers=config.train_num_workers, pin_memory=True, drop_last=True)
+ else:
+ loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True,
+ num_workers=config.train_num_workers)
+ elif data_type == "val":
+ dataset = get_dataset(
+ config,
+ config.val_tsv_file,
+ config.val_pic_dir,
+ config.loader_type,
+ is_train=False)
+ loader = DataLoader(dataset, shuffle=False, batch_size=config.val_batch_size,
+ num_workers=config.val_num_workers)
+ elif data_type == "test":
+ dataset = get_dataset(
+ config,
+ config.test_tsv_file,
+ config.test_pic_dir,
+ config.loader_type,
+ is_train=False)
+ loader = DataLoader(dataset, shuffle=False, batch_size=config.test_batch_size,
+ num_workers=config.test_num_workers)
+ else:
+ assert False
+ return loader
+
+
+def get_optimizer(config, net):
+ params = net.parameters()
+
+ optimizer = None
+ if config.optimizer == "sgd":
+ optimizer = optim.SGD(
+ params,
+ lr=config.learn_rate,
+ momentum=config.momentum,
+ weight_decay=config.weight_decay,
+ nesterov=config.nesterov)
+ elif config.optimizer == "adam":
+ optimizer = optim.Adam(
+ params,
+ lr=config.learn_rate)
+ elif config.optimizer == "rmsprop":
+ optimizer = optim.RMSprop(
+ params,
+ lr=config.learn_rate,
+ momentum=config.momentum,
+ alpha=config.alpha,
+ eps=config.epsilon,
+ weight_decay=config.weight_decay
+ )
+ else:
+ assert False
+ return optimizer
+
+
+def get_scheduler(config, optimizer):
+ if config.scheduler == "MultiStepLR":
+ scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.milestones, gamma=config.gamma)
+ else:
+ assert False
+ return scheduler
+
+
+def get_net(config):
+ net = None
+ if config.net == "stackedHGnet_v1":
+ net = StackedHGNetV1(config=config,
+ classes_num=config.classes_num,
+ edge_info=config.edge_info,
+ nstack=config.nstack,
+ add_coord=config.add_coord,
+ decoder_type=config.decoder_type)
+ else:
+ assert False
+ return net
+
+
+def get_criterions(config):
+ criterions = list()
+ for k in range(config.label_num):
+ if config.criterions[k] == "AWingLoss":
+ criterion = AWingLoss()
+ elif config.criterions[k] == "smoothl1":
+ criterion = SmoothL1Loss()
+ elif config.criterions[k] == "l1":
+ criterion = F.l1_loss
+ elif config.criterions[k] == 'l2':
+ criterion = F.mse_loss
+ elif config.criterions[k] == "STARLoss":
+ criterion = STARLoss(dist=config.star_dist, w=config.star_w)
+ elif config.criterions[k] == "STARLoss_v2":
+ criterion = STARLoss_v2(dist=config.star_dist, w=config.star_w)
+ else:
+ assert False
+ criterions.append(criterion)
+ return criterions
+
+
+def set_environment(config):
+ if config.device_id >= 0:
+ assert torch.cuda.is_available() and torch.cuda.device_count() > config.device_id
+ torch.cuda.empty_cache()
+ config.device = torch.device("cuda", config.device_id)
+ config.use_gpu = True
+ else:
+ config.device = torch.device("cpu")
+ config.use_gpu = False
+
+ torch.set_default_dtype(torch.float32)
+ torch.set_default_tensor_type(torch.FloatTensor)
+ torch.set_flush_denormal(True) # ignore extremely small value
+ torch.backends.cudnn.benchmark = True # This flag allows you to enable the inbuilt cudnn auto-tuner to find the best algorithm to use for your hardware.
+ torch.autograd.set_detect_anomaly(True)
+
+
+def forward(config, test_loader, net):
+ # ave_metrics = [[0, 0] for i in range(config.label_num)]
+ list_nmes = [[] for i in range(config.label_num)]
+ metric_nme = NME(nme_left_index=config.nme_left_index, nme_right_index=config.nme_right_index)
+ metric_fr_auc = FR_AUC(data_definition=config.data_definition)
+
+ output_pd = None
+
+ net = net.float().to(config.device)
+ net.eval()
+ dataset_size = len(test_loader.dataset)
+ batch_size = test_loader.batch_size
+ if config.logger is not None:
+ config.logger.info("Forward process, Dataset size: %d, Batch size: %d" % (dataset_size, batch_size))
+ for i, sample in enumerate(tqdm(test_loader)):
+ input = sample["data"].float().to(config.device, non_blocking=True)
+ labels = list()
+ if isinstance(sample["label"], list):
+ for label in sample["label"]:
+ label = label.float().to(config.device, non_blocking=True)
+ labels.append(label)
+ else:
+ label = sample["label"].float().to(config.device, non_blocking=True)
+ for k in range(label.shape[1]):
+ labels.append(label[:, k])
+ labels = config.nstack * labels
+
+ with torch.no_grad():
+ output, heatmap, landmarks = net(input)
+
+ # metrics
+ for k in range(config.label_num):
+ if config.metrics[k] is not None:
+ list_nmes[k] += metric_nme.test(output[k], labels[k])
+
+ metrics = [[np.mean(nmes), ] + metric_fr_auc.test(nmes) for nmes in list_nmes]
+
+ return output_pd, metrics
+
+
+def compute_loss(config, criterions, output, labels, heatmap=None, landmarks=None):
+ batch_weight = 1.0
+ sum_loss = 0
+ losses = list()
+ for k in range(config.label_num):
+ if config.criterions[k] in ['smoothl1', 'l1', 'l2', 'WingLoss', 'AWingLoss']:
+ loss = criterions[k](output[k], labels[k])
+ elif config.criterions[k] in ["STARLoss", "STARLoss_v2"]:
+ _k = int(k / 3) if config.use_AAM else k
+ loss = criterions[k](heatmap[_k], labels[k])
+ else:
+ assert NotImplementedError
+ loss = batch_weight * loss
+ sum_loss += config.loss_weights[k] * loss
+ loss = float(loss.data.cpu().item())
+ losses.append(loss)
+ return losses, sum_loss
+
+
+def forward_backward(config, train_loader, net_module, net, net_ema, criterions, optimizer, epoch):
+ train_model_time = AverageMeter()
+ ave_losses = [0] * config.label_num
+
+ net_module = net_module.float().to(config.device)
+ net_module.train(True)
+ dataset_size = len(train_loader.dataset)
+ batch_size = config.batch_size # train_loader.batch_size
+ batch_num = max(dataset_size / max(batch_size, 1), 1)
+ if config.logger is not None:
+ config.logger.info(config.note)
+ config.logger.info("Forward Backward process, Dataset size: %d, Batch size: %d" % (dataset_size, batch_size))
+
+ iter_num = len(train_loader)
+ epoch_start_time = time.time()
+ if net_module != net:
+ train_loader.sampler.set_epoch(epoch)
+ for iter, sample in enumerate(train_loader):
+ iter_start_time = time.time()
+ # input
+ input = sample["data"].float().to(config.device, non_blocking=True)
+ # labels
+ labels = list()
+ if isinstance(sample["label"], list):
+ for label in sample["label"]:
+ label = label.float().to(config.device, non_blocking=True)
+ labels.append(label)
+ else:
+ label = sample["label"].float().to(config.device, non_blocking=True)
+ for k in range(label.shape[1]):
+ labels.append(label[:, k])
+ labels = config.nstack * labels
+ # forward
+ output, heatmaps, landmarks = net_module(input)
+
+ # loss
+ losses, sum_loss = compute_loss(config, criterions, output, labels, heatmaps, landmarks)
+ ave_losses = list(map(sum, zip(ave_losses, losses)))
+
+ # backward
+ optimizer.zero_grad()
+ with torch.autograd.detect_anomaly():
+ sum_loss.backward()
+ # torch.nn.utils.clip_grad_norm_(net_module.parameters(), 128.0)
+ optimizer.step()
+
+ if net_ema is not None:
+ accumulate_net(net_ema, net, 0.5 ** (config.batch_size / 10000.0))
+ # accumulate_net(net_ema, net, 0.5 ** (8 / 10000.0))
+
+ # output
+ train_model_time.update(time.time() - iter_start_time)
+ last_time = convert_secs2time(train_model_time.avg * (iter_num - iter - 1), True)
+ if iter % config.display_iteration == 0 or iter + 1 == len(train_loader):
+ if config.logger is not None:
+ losses_str = ' Average Loss: {:.6f}'.format(sum(losses) / len(losses))
+ for k, loss in enumerate(losses):
+ losses_str += ', L{}: {:.3f}'.format(k, loss)
+ config.logger.info(
+ ' -->>[{:03d}/{:03d}][{:03d}/{:03d}]'.format(epoch, config.max_epoch, iter, iter_num) \
+ + last_time + losses_str)
+
+ epoch_end_time = time.time()
+ epoch_total_time = epoch_end_time - epoch_start_time
+ epoch_load_data_time = epoch_total_time - train_model_time.sum
+ if config.logger is not None:
+ config.logger.info("Train/Epoch: %d/%d, Average total time cost per iteration in this epoch: %.6f" % (
+ epoch, config.max_epoch, epoch_total_time / iter_num))
+ config.logger.info("Train/Epoch: %d/%d, Average loading data time cost per iteration in this epoch: %.6f" % (
+ epoch, config.max_epoch, epoch_load_data_time / iter_num))
+ config.logger.info("Train/Epoch: %d/%d, Average training model time cost per iteration in this epoch: %.6f" % (
+ epoch, config.max_epoch, train_model_time.avg))
+
+ ave_losses = [loss / iter_num for loss in ave_losses]
+ if config.logger is not None:
+ config.logger.info("Train/Epoch: %d/%d, Average Loss in this epoch: %.6f" % (
+ epoch, config.max_epoch, sum(ave_losses) / len(ave_losses)))
+ for k, ave_loss in enumerate(ave_losses):
+ if config.logger is not None:
+ config.logger.info("Train/Loss%03d in this epoch: %.6f" % (k, ave_loss))
+
+
+def accumulate_net(model1, model2, decay):
+ """
+ operation: model1 = model1 * decay + model2 * (1 - decay)
+ """
+ par1 = dict(model1.named_parameters())
+ par2 = dict(model2.named_parameters())
+ for k in par1.keys():
+ par1[k].data.mul_(decay).add_(
+ other=par2[k].data.to(par1[k].data.device),
+ alpha=1 - decay)
+
+ par1 = dict(model1.named_buffers())
+ par2 = dict(model2.named_buffers())
+ for k in par1.keys():
+ if par1[k].data.is_floating_point():
+ par1[k].data.mul_(decay).add_(
+ other=par2[k].data.to(par1[k].data.device),
+ alpha=1 - decay)
+ else:
+ par1[k].data = par2[k].data.to(par1[k].data.device)
+
+
+def save_model(config, epoch, net, net_ema, optimizer, scheduler, pytorch_model_path):
+ # save pytorch model
+ state = {
+ "net": net.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "epoch": epoch
+ }
+ if config.ema:
+ state["net_ema"] = net_ema.state_dict()
+
+ torch.save(state, pytorch_model_path)
+ if config.logger is not None:
+ config.logger.info("Epoch: %d/%d, model saved in this epoch" % (epoch, config.max_epoch))
diff --git a/external/landmark_detection/lib/utils/__init__.py b/external/landmark_detection/lib/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd9f8f789132cea653b36bb14f333e3208dbdb2
--- /dev/null
+++ b/external/landmark_detection/lib/utils/__init__.py
@@ -0,0 +1,16 @@
+from .meter import AverageMeter
+from .time_utils import time_print, time_string, time_string_short, time_for_file
+from .time_utils import convert_secs2time, convert_size2str
+from .vis_utils import plot_points
+
+__all__ = [
+ "AverageMeter",
+ "time_print",
+ "time_string",
+ "time_string_short",
+ "time_for_file",
+ "convert_size2str",
+ "convert_secs2time",
+
+ "plot_points",
+]
diff --git a/external/landmark_detection/lib/utils/dist_utils.py b/external/landmark_detection/lib/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f9a00a07dab89aad8146c927adef22fbcf4e06b
--- /dev/null
+++ b/external/landmark_detection/lib/utils/dist_utils.py
@@ -0,0 +1,183 @@
+import torch
+from torch.autograd import Variable
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
+def get_channel_sum(input):
+ """
+ Generates the sum of each channel of the input
+ input = batch_size x 68 x 64 x 64
+ output = batch_size x 68
+ """
+ temp = torch.sum(input, dim=3)
+ output = torch.sum(temp, dim=2)
+
+ return output
+
+
+def expand_two_dimensions_at_end(input, dim1, dim2):
+ """
+ Adds two more dimensions to the end of the input
+ input = batch_size x 68
+ output= batch_size x 68 x dim1 x dim2
+ """
+ input = input.unsqueeze(-1).unsqueeze(-1)
+ input = input.expand(-1, -1, dim1, dim2)
+
+ return input
+
+
+class Distribution(object):
+ def __init__(self, heatmaps, num_dim_dist=2, EPSILON=1e-5, is_normalize=True):
+ self.heatmaps = heatmaps
+ self.num_dim_dist = num_dim_dist
+ self.EPSILON = EPSILON
+ self.is_normalize = is_normalize
+ batch, npoints, h, w = heatmaps.shape
+ # normalize
+ heatmap_sum = torch.clamp(heatmaps.sum([2, 3]), min=1e-6)
+ self.heatmaps = heatmaps / heatmap_sum.view(batch, npoints, 1, 1)
+
+ # means [batch_size x 68 x 2]
+ self.mean = self.get_spatial_mean(self.heatmaps)
+ # covars [batch_size x 68 x 2 x 2]
+ self.covars = self.get_covariance_matrix(self.heatmaps, self.mean)
+
+ _covars = self.covars.view(batch * npoints, 2, 2).cpu()
+ evalues, evectors = _covars.symeig(eigenvectors=True)
+ # eigenvalues [batch_size x 68 x 2]
+ self.evalues = evalues.view(batch, npoints, 2).to(heatmaps)
+ # eignvectors [batch_size x 68 x 2 x 2]
+ self.evectors = evectors.view(batch, npoints, 2, 2).to(heatmaps)
+
+ def __repr__(self):
+ return "Distribution()"
+
+ def plot(self, heatmap, mean, evalues, evectors):
+ # heatmap is not normalized
+ plt.figure(0)
+ if heatmap.is_cuda:
+ heatmap, mean = heatmap.cpu(), mean.cpu()
+ evalues, evectors = evalues.cpu(), evectors.cpu()
+ sns.heatmap(heatmap, cmap="RdBu_r")
+ for evalue, evector in zip(evalues, evectors):
+ plt.arrow(mean[0], mean[1], evalue * evector[0], evalue * evector[1],
+ width=0.2, shape="full")
+ plt.show()
+
+ def easy_plot(self, index):
+ # index = (num of batch_size, num of num_points)
+ num_bs, num_p = index
+ heatmap = self.heatmaps[num_bs, num_p]
+ mean = self.mean[num_bs, num_p]
+ evalues = self.evalues[num_bs, num_p]
+ evectors = self.evectors[num_bs, num_p]
+ self.plot(heatmap, mean, evalues, evectors)
+
+ def project_and_scale(self, pts, eigenvalues, eigenvectors):
+ batch_size, npoints, _ = pts.shape
+ proj_pts = torch.matmul(pts.view(batch_size, npoints, 1, 2), eigenvectors)
+ scale_proj_pts = proj_pts.view(batch_size, npoints, 2) / torch.sqrt(eigenvalues)
+ return scale_proj_pts
+
+ def _make_grid(self, h, w):
+ if self.is_normalize:
+ yy, xx = torch.meshgrid(
+ torch.arange(h).float() / (h - 1) * 2 - 1,
+ torch.arange(w).float() / (w - 1) * 2 - 1)
+ else:
+ yy, xx = torch.meshgrid(
+ torch.arange(h).float(),
+ torch.arange(w).float()
+ )
+
+ return yy, xx
+
+ def get_spatial_mean(self, heatmap):
+ batch, npoints, h, w = heatmap.shape
+
+ yy, xx = self._make_grid(h, w)
+ yy = yy.view(1, 1, h, w).to(heatmap)
+ xx = xx.view(1, 1, h, w).to(heatmap)
+
+ yy_coord = (yy * heatmap).sum([2, 3]) # batch x npoints
+ xx_coord = (xx * heatmap).sum([2, 3]) # batch x npoints
+ coords = torch.stack([xx_coord, yy_coord], dim=-1)
+ return coords
+
+ def get_covariance_matrix(self, htp, means):
+ """
+ Covariance calculation from the normalized heatmaps
+ Reference https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_covariance
+ The unbiased estimate is given by
+ Unbiased covariance =
+ ___
+ \
+ /__ w_i (x_i - \mu_i)^T (x_i - \mu_i)
+
+ ___________________________________________
+
+ V_1 - (V_2/V_1)
+
+ ___ ___
+ \ \
+ where V_1 = /__ w_i and V_2 = /__ w_i^2
+
+
+ Input:
+ htp = batch_size x 68 x 64 x 64
+ means = batch_size x 68 x 2
+
+ Output:
+ covariance = batch_size x 68 x 2 x 2
+ """
+ batch_size = htp.shape[0]
+ num_points = htp.shape[1]
+ height = htp.shape[2]
+ width = htp.shape[3]
+
+ yv, xv = self._make_grid(height, width)
+ xv = Variable(xv)
+ yv = Variable(yv)
+
+ if htp.is_cuda:
+ xv = xv.cuda()
+ yv = yv.cuda()
+
+ xmean = means[:, :, 0]
+ xv_minus_mean = xv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(xmean, height,
+ width) # batch_size x 68 x 64 x 64
+ ymean = means[:, :, 1]
+ yv_minus_mean = yv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(ymean, height,
+ width) # batch_size x 68 x 64 x 64
+
+ # These are the unweighted versions
+ wt_xv_minus_mean = xv_minus_mean
+ wt_yv_minus_mean = yv_minus_mean
+
+ wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, height * width) # batch_size*68 x 4096
+ wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, 1,
+ height * width) # batch_size*68 x 1 x 4096
+ wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, height * width) # batch_size*68 x 4096
+ wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, 1,
+ height * width) # batch_size*68 x 1 x 4096
+ vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1) # batch_size*68 x 2 x 4096
+
+ htp_vec = htp.view(batch_size * num_points, 1, height * width)
+ htp_vec = htp_vec.expand(-1, 2, -1)
+
+ # Torch batch matrix multiplication
+ # https://pytorch.org/docs/stable/torch.html#torch.bmm
+ # Also use the heatmap as the weights at one place now
+ covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2)) # batch_size*68 x 2 x 2
+ covariance = covariance.view(batch_size, num_points, self.num_dim_dist,
+ self.num_dim_dist) # batch_size x 68 x 2 x 2
+
+ V_1 = get_channel_sum(htp) + self.EPSILON # batch_size x 68
+ V_2 = get_channel_sum(torch.pow(htp, 2)) # batch_size x 68
+ denominator = V_1 - (V_2 / V_1)
+
+ covariance = covariance / expand_two_dimensions_at_end(denominator, self.num_dim_dist, self.num_dim_dist)
+
+ return (covariance)
diff --git a/external/landmark_detection/lib/utils/meter.py b/external/landmark_detection/lib/utils/meter.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ff766dff92255a7ed8106f29ccecdc08da1c139
--- /dev/null
+++ b/external/landmark_detection/lib/utils/meter.py
@@ -0,0 +1,20 @@
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0.0
+ self.avg = 0.0
+ self.sum = 0.0
+ self.count = 0.0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __repr__(self):
+ return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
\ No newline at end of file
diff --git a/external/landmark_detection/lib/utils/time_utils.py b/external/landmark_detection/lib/utils/time_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcdcc0590cac7851b22360f4db2b0499d1b48dd7
--- /dev/null
+++ b/external/landmark_detection/lib/utils/time_utils.py
@@ -0,0 +1,49 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+import time, sys
+import numpy as np
+
+
+def time_for_file():
+ ISOTIMEFORMAT = '%d-%h-at-%H-%M-%S'
+ return '{}'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
+
+
+def time_string():
+ ISOTIMEFORMAT = '%Y-%m-%d %X'
+ string = '[{}]'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
+ return string
+
+
+def time_string_short():
+ ISOTIMEFORMAT = '%Y%m%d'
+ string = '{}'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
+ return string
+
+
+def time_print(string, is_print=True):
+ if (is_print):
+ print('{} : {}'.format(time_string(), string))
+
+
+def convert_size2str(torch_size):
+ dims = len(torch_size)
+ string = '['
+ for idim in range(dims):
+ string = string + ' {}'.format(torch_size[idim])
+ return string + ']'
+
+
+def convert_secs2time(epoch_time, return_str=False):
+ need_hour = int(epoch_time / 3600)
+ need_mins = int((epoch_time - 3600 * need_hour) / 60)
+ need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
+ if return_str:
+ str = '[Time Left: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
+ return str
+ else:
+ return need_hour, need_mins, need_secs
diff --git a/external/landmark_detection/lib/utils/vis_utils.py b/external/landmark_detection/lib/utils/vis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a26cc48e907efb67734e159eba2363b4822840e1
--- /dev/null
+++ b/external/landmark_detection/lib/utils/vis_utils.py
@@ -0,0 +1,31 @@
+import cv2
+import numpy as np
+import numbers
+
+
+def plot_points(vis, points, radius=1, color=(255, 255, 0), shift=4, indexes=0, is_index=False):
+ if isinstance(points, list):
+ num_point = len(points)
+ elif isinstance(points, np.numarray):
+ num_point = points.shape[0]
+ else:
+ raise NotImplementedError
+ if isinstance(radius, numbers.Number):
+ radius = np.zeros((num_point)) + radius
+
+ if isinstance(indexes, numbers.Number):
+ indexes = [indexes + i for i in range(num_point)]
+ elif isinstance(indexes, list):
+ pass
+ else:
+ raise NotImplementedError
+
+ factor = (1 << shift)
+ for (index, p, s) in zip(indexes, points, radius):
+ cv2.circle(vis, (int(p[0] * factor + 0.5), int(p[1] * factor + 0.5)),
+ int(s * factor), color, 1, cv2.LINE_AA, shift=shift)
+ if is_index:
+ vis = cv2.putText(vis, str(index), (int(p[0]), int(p[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.2,
+ (255, 255, 255), 1)
+
+ return vis
diff --git a/external/landmark_detection/requirements.txt b/external/landmark_detection/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..382636fe60b98f1ca13596f43ccfe791e1f4e860
--- /dev/null
+++ b/external/landmark_detection/requirements.txt
@@ -0,0 +1,19 @@
+tqdm
+torch==1.6.0
+torchvision==0.7.0
+python-gflags==3.1.2
+pandas==0.24.2
+pillow==6.0.0
+numpy==1.16.4
+opencv-python==4.1.0.25
+imageio==2.5.0
+imgaug==0.2.9
+lmdb==0.98
+lxml==4.5.0
+tensorboard==2.4.1
+protobuf==3.20
+tensorboardX==1.8
+# pyarrow==0.17.1
+# wandb==0.10.25
+# https://pytorch.org/get-started/previous-versions/
+# pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
diff --git a/external/landmark_detection/tester.py b/external/landmark_detection/tester.py
new file mode 100644
index 0000000000000000000000000000000000000000..50ae7946c3cd01bfc6525a4a29c18819ba2f4ac9
--- /dev/null
+++ b/external/landmark_detection/tester.py
@@ -0,0 +1,49 @@
+import os
+import torch
+from lib import utility
+
+
+def test(args):
+ # conf
+ config = utility.get_config(args)
+ config.device_id = args.device_ids[0]
+
+ # set environment
+ utility.set_environment(config)
+ config.init_instance()
+ if config.logger is not None:
+ config.logger.info("Loaded configure file %s: %s" % (args.config_name, config.id))
+ config.logger.info("\n" + "\n".join(["%s: %s" % item for item in config.__dict__.items()]))
+
+ # model
+ net = utility.get_net(config)
+ model_path = os.path.join(config.model_dir,
+ "train.pkl") if args.pretrained_weight is None else args.pretrained_weight
+ if args.device_ids == [-1]:
+ checkpoint = torch.load(model_path, map_location="cpu")
+ else:
+ checkpoint = torch.load(model_path)
+
+ net.load_state_dict(checkpoint["net"])
+
+ if config.logger is not None:
+ config.logger.info("Loaded network")
+ # config.logger.info('Net flops: {} G, params: {} MB'.format(flops/1e9, params/1e6))
+
+ # data - test
+ test_loader = utility.get_dataloader(config, "test")
+
+ if config.logger is not None:
+ config.logger.info("Loaded data from {:}".format(config.test_tsv_file))
+
+ # inference
+ result, metrics = utility.forward(config, test_loader, net)
+ if config.logger is not None:
+ config.logger.info("Finished inference")
+
+ # output
+ for k, metric in enumerate(metrics):
+ if config.logger is not None and len(metric) != 0:
+ config.logger.info(
+ "Tested {} dataset, the Size is {}, Metric: [NME {:.6f}, FR {:.6f}, AUC {:.6f}]".format(
+ config.type, len(test_loader.dataset), metric[0], metric[1], metric[2]))
diff --git a/external/landmark_detection/tools/analysis_motivation.py b/external/landmark_detection/tools/analysis_motivation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cbc62810cf572c4e06fc25b36617450e2acdc86
--- /dev/null
+++ b/external/landmark_detection/tools/analysis_motivation.py
@@ -0,0 +1,220 @@
+import glob
+import json
+import os.path as osp
+import numpy as np
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+import seaborn as sns
+from pandas import DataFrame
+import pandas as pd
+
+
+def L2(p1, p2):
+ return np.linalg.norm(p1 - p2)
+
+
+def NME(landmarks_gt, landmarks_pv):
+ pts_num = landmarks_gt.shape[0]
+ if pts_num == 29:
+ left_index = 16
+ right_index = 17
+ elif pts_num == 68:
+ left_index = 36
+ right_index = 45
+ elif pts_num == 98:
+ left_index = 60
+ right_index = 72
+
+ nme = 0
+ eye_span = L2(landmarks_gt[left_index], landmarks_gt[right_index])
+ nmeList = []
+ for i in range(pts_num):
+ error = L2(landmarks_pv[i], landmarks_gt[i])
+ _nme = error / eye_span
+ nmeList.append(_nme)
+ nme += _nme
+ nme /= pts_num
+ return nme, nmeList
+
+
+def NME_analysis(listA):
+ for jsonA in listA:
+ pred = np.array(jsonA['pred'])
+ gt = np.array(jsonA['gt'])
+ nme, nmeList = NME(gt, pred)
+ jsonA['nme'] = nme
+ jsonA['nmeList'] = nmeList
+ return listA
+
+
+def nme_analysis(listA):
+ bdy_nmeList = []
+ scene_nmeList = []
+ for jsonA in tqdm(listA):
+ nme = jsonA['nmeList']
+ nme = np.array(nme)
+ bdy_nme = np.mean(nme[:33])
+ scene_nme = np.mean(nme[33:])
+ # scene_nme = np.mean(nme[[33, 35, 40, 38,
+ # 60, 62, 96, 66, 64,
+ # 50, 44, 48, 46,
+ # 68, 70, 97, 74, 72,
+ # 54, 55, 57, 59,
+ # 76, 82, 79, 90, 94, 85, 16]])
+ bdy_nmeList.append(bdy_nme)
+ scene_nmeList.append(scene_nme)
+ print('bdy nme: {:.4f}'.format(np.mean(bdy_nmeList)))
+ print('scene_nmeList: {:.4f}'.format(np.mean(scene_nmeList)))
+
+
+def Energy_analysis(listA, easyThresh=0.02, easyNum=10, hardThresh=0.07, hardNum=10):
+ easyDict = {'energy': [], 'nme': []}
+ hardDict = {'energy': [], 'nme': []}
+
+ _easyNum, _hardNum = 0, 0
+
+ def cal_energy(evalues):
+ evalues = np.array(evalues)
+ # _energy = _energy.max(1)
+ eccentricity = evalues.max(1) / evalues.min(1)
+ # _energy = _energy.sum() / 2
+ _energy = np.mean(eccentricity)
+ return _energy
+
+ for jsonA in tqdm(listA):
+ nme = jsonA['nme']
+ evalues = jsonA['evalues']
+
+ if _easyNum == easyNum and _hardNum == hardNum:
+ break
+
+ if nme < easyThresh and _easyNum < easyNum:
+ energy = cal_energy(evalues)
+ easyDict['energy'].append(energy)
+ easyDict['nme'].append(nme)
+ _easyNum += 1
+ elif nme > hardThresh and _hardNum < hardNum:
+ energy = cal_energy(evalues)
+ hardDict['energy'].append(energy)
+ hardDict['nme'].append(nme)
+ _hardNum += 1
+
+ print('easyThresh: < {}; hardThresh > {}'.format(easyThresh, hardThresh))
+ print(' |nme |energy |num |')
+ print('easy samples: |{:.4f} |{:.4f} |{} |'.format(np.mean(easyDict['nme']),
+ np.mean(easyDict['energy']),
+ len(easyDict['energy'])))
+ print('hard samples: |{:.4f} |{:.4f} |{} |'.format(np.mean(hardDict['nme']),
+ np.mean(hardDict['energy']),
+ len(hardDict['energy'])))
+
+ return easyDict, hardDict
+
+
+def Eccentricity_analysis(listA):
+ eyecornerList = []
+ boundaryList = []
+ for jsonA in listA:
+ evalues = np.array(jsonA['evalues'])
+ eccentricity = evalues.max(1) / evalues.min(1)
+
+ eyecorner = np.mean(eccentricity[[60, 64, 68, 72]])
+ boundary = np.mean(eccentricity[0:33])
+ eyecornerList.append(eyecorner)
+ boundaryList.append(boundary)
+
+ print('eyecorner: {:.4f}'.format(np.mean(eyecornerList)))
+ print('boundary: {:.4f}'.format(np.mean(boundaryList)))
+ return eyecornerList, boundaryList
+
+
+def plot_bar(dataList):
+ x = list(range(98))
+ assert len(x) == len(dataList)
+ _x = 'Landmark Index'
+ # _y = 'elliptical eccentricity (λ1/λ2)'
+ _y = 'PCA Analyze (λ1/λ2)'
+ data = {
+ _x: x,
+ _y: dataList
+ }
+ df = DataFrame(data)
+ plt.figure(figsize=(10, 4))
+ sns.barplot(x=_x, y=_y, data=df)
+ plt.show()
+
+
+def Eccentricity_analysis2(listA, is_vis=False):
+ landmarksList = [[] for i in range(98)]
+ for jsonA in listA:
+ evalues = np.array(jsonA['evalues'])
+ eccentricity = evalues.max(1) / evalues.min(1)
+ for i, e in enumerate(eccentricity):
+ landmarksList[i].append(e)
+ print('Mean value: {:.4f}'.format(np.mean(np.array(landmarksList))))
+ landmarksList = [np.mean(l) for l in landmarksList]
+ if is_vis:
+ plot_bar(landmarksList)
+ return landmarksList
+
+
+def std_analysis2():
+ save_dir = '/apdcephfs/share_1134483/charlinzhou/experiment/cvpr-23/wflw_results'
+ # l2_npy = glob.glob(osp.join(save_dir, '*DSNT*.npy'))
+ l2_npy = glob.glob(osp.join(save_dir, '*MHNLoss_v2_l2*.npy'))
+
+ def npy2std(npyList):
+ datas = [np.load(npy)[np.newaxis, :] for npy in npyList]
+ datas = np.concatenate(datas, axis=0)
+ # denormalization
+ datas = (datas + 1) * 256 / 2
+ mean = datas.mean(axis=0)[np.newaxis, :]
+ dist = np.linalg.norm(datas - mean, axis=-1)
+ std = np.std(dist, 0)
+ print('min: {}, max:{}, mean:{}'.format(std.min(), std.max(), std.mean()))
+ return std
+
+ std1 = npy2std(l2_npy)
+ std1 = std1.mean(0)
+ # plot_bar(std1)
+ bdy_std = np.mean(std1[:33])
+ cofw_std = np.mean(std1[[33, 35, 40, 38,
+ 60, 62, 96, 66, 64,
+ 50, 44, 48, 46,
+ 68, 70, 97, 74, 72,
+ 54, 55, 57, 59,
+ 76, 82, 79, 90, 94, 85, 16]])
+ print('bdy_std: {:.4f}, cofw_std: {:.4f}'.format(bdy_std, cofw_std))
+ print('the ratio of Boundary std and ALL std: {:.4f} / {:.4f}'.format(np.sum(std1[:33]), np.sum(std1)))
+
+
+if __name__ == '__main__':
+ # 4.29模型
+ json_path = '/apdcephfs/share_1134483/charlinzhou/ckpts/STAR/WFLW/WFLW_256x256_adam_ep500_lr0.001_bs128_STARLoss_smoothl1_1_b0183746-161a-4b76-9cb9-8a2059090233/results.json'
+ # 无初始化
+ # json_path = '/apdcephfs/share_1134483/charlinzhou/ckpts/STAR/WFLW/WFLW_256x256_adam_ep500_lr0.001_bs128_STARLoss_smoothl1_1_9cff3656-8ca8-4c3d-a95d-da76f9f76ea5/results.json'
+ # 4.02模型
+ # json_path = '/apdcephfs/share_1134483/charlinzhou/ckpts/STAR/WFLW/WFLW_256x256_adam_ep500_lr0.001_bs128_STARLoss_smoothl1_1_AAM_2d2bb70e-6fdb-459c-baf7-18c89e7a165f/results.json'
+ listA = json.load(open(json_path, 'r'))
+ print('Load Done!')
+ listA = NME_analysis(listA)
+ print('NME analysis Done!')
+ # Exp1: 分析简单样本和困难样本的能量差异
+ easyDict, hardDict = Energy_analysis(listA, easyNum=2500, hardNum=2500, easyThresh=0.03, hardThresh=0.08)
+
+ # Exp2.1: 分析眼角点和轮廓点的斜率差异
+ # eyecornerList, boundaryList = Eccentricity_analysis(listA)
+
+ # Exp2.2: 可视化所有点的斜率分布
+ # landmarksList = Eccentricity_analysis2(listA, is_vis=True)
+
+ # Exp2.3: 可视化所有点的方差分布
+ # std_analysis2()
+
+ # Exp3: 五官和轮廓NME分析
+ # nme_analysis(listA)
+ # print(easyDict)
+ # print(hardDict)
+
+ # nmeList = [jsonA['nme'] for jsonA in listA]
+ # print(len(nmeList))
diff --git a/external/landmark_detection/tools/infinite_loop.py b/external/landmark_detection/tools/infinite_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..275fc170d601e0d445219263bec1b2a67f268170
--- /dev/null
+++ b/external/landmark_detection/tools/infinite_loop.py
@@ -0,0 +1,4 @@
+import time
+
+while True:
+ time.sleep(1)
diff --git a/external/landmark_detection/tools/infinite_loop_gpu.py b/external/landmark_detection/tools/infinite_loop_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd32b9c618b036a3bb976168f8f54b216f760a5
--- /dev/null
+++ b/external/landmark_detection/tools/infinite_loop_gpu.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+
+import os
+import time
+import torch
+import argparse
+
+parser = argparse.ArgumentParser(description='inf')
+parser.add_argument('--gpu', default='1', type=str, help='index of gpu to use')
+args = parser.parse_args()
+
+os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
+
+n = 1000
+
+x = torch.zeros(4, n, n).cuda()
+rest_time = 0.0000000000001
+while True:
+ y = x * x
+ time.sleep(rest_time)
+ y1 = x * x
diff --git a/external/landmark_detection/tools/split_wflw.py b/external/landmark_detection/tools/split_wflw.py
new file mode 100644
index 0000000000000000000000000000000000000000..1946b68c9a4541abb04ae8bac83045a4b8be339d
--- /dev/null
+++ b/external/landmark_detection/tools/split_wflw.py
@@ -0,0 +1,38 @@
+import csv
+import os.path as osp
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+tsv_file = '/apdcephfs/share_1134483/charlinzhou/datas/ADNet/WFLW/test.tsv'
+save_folder = '/apdcephfs/share_1134483/charlinzhou/datas/ADNet/_WFLW/'
+
+save_tags = ['largepose', 'expression', 'illumination', 'makeup', 'occlusion', 'blur']
+save_tags = ['test_{}_metadata.tsv'.format(t) for t in save_tags]
+save_files = [osp.join(save_folder, t) for t in save_tags]
+save_files = [open(f, 'w', newline='') for f in save_files]
+
+landmark_num = 98
+items = pd.read_csv(tsv_file, sep="\t")
+
+items_num = len(items)
+for index in tqdm(range(items_num)):
+ image_path = items.iloc[index, 0]
+ landmarks_5pts = items.iloc[index, 1]
+ # landmarks_5pts = np.array(list(map(float, landmarks_5pts.split(","))), dtype=np.float32).reshape(5, 2)
+ landmarks_target = items.iloc[index, 2]
+ # landmarks_target = np.array(list(map(float, landmarks_target.split(","))), dtype=np.float32).reshape(landmark_num, 2)
+ scale = items.iloc[index, 3]
+ center_w, center_h = items.iloc[index, 4], items.iloc[index, 5]
+ if len(items.iloc[index]) > 6:
+ tags = np.array(list(map(lambda x: int(float(x)), items.iloc[index, 6].split(","))))
+ else:
+ tags = np.array([])
+ assert len(tags) == 6, '{} v.s. 6'.format(len(tags))
+ for k, tag in enumerate(tags):
+ if tag == 1:
+ save_file = save_files[k]
+ tsv_w = csv.writer(save_file, delimiter='\t')
+ tsv_w.writerow([image_path, landmarks_5pts, landmarks_target, scale, center_w, center_h])
+
+print('Done!')
diff --git a/external/landmark_detection/tools/testtime_pca.py b/external/landmark_detection/tools/testtime_pca.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b7634e680f86001139c48914ab09ed65a798570
--- /dev/null
+++ b/external/landmark_detection/tools/testtime_pca.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+
+
+def get_channel_sum(input):
+ temp = torch.sum(input, dim=3)
+ output = torch.sum(temp, dim=2)
+ return output
+
+
+def expand_two_dimensions_at_end(input, dim1, dim2):
+ input = input.unsqueeze(-1).unsqueeze(-1)
+ input = input.expand(-1, -1, dim1, dim2)
+ return input
+
+
+class TestTimePCA(nn.Module):
+ def __init__(self):
+ super(TestTimePCA, self).__init__()
+
+ def _make_grid(self, h, w):
+ yy, xx = torch.meshgrid(
+ torch.arange(h).float() / (h - 1) * 2 - 1,
+ torch.arange(w).float() / (w - 1) * 2 - 1)
+ return yy, xx
+
+ def weighted_mean(self, heatmap):
+ batch, npoints, h, w = heatmap.shape
+
+ yy, xx = self._make_grid(h, w)
+ yy = yy.view(1, 1, h, w).to(heatmap)
+ xx = xx.view(1, 1, h, w).to(heatmap)
+
+ yy_coord = (yy * heatmap).sum([2, 3]) # batch x npoints
+ xx_coord = (xx * heatmap).sum([2, 3]) # batch x npoints
+ coords = torch.stack([xx_coord, yy_coord], dim=-1)
+ return coords
+
+ def unbiased_weighted_covariance(self, htp, means, num_dim_image=2, EPSILON=1e-5):
+ batch_size, num_points, height, width = htp.shape
+
+ yv, xv = self._make_grid(height, width)
+ xv = Variable(xv)
+ yv = Variable(yv)
+
+ if htp.is_cuda:
+ xv = xv.cuda()
+ yv = yv.cuda()
+
+ xmean = means[:, :, 0]
+ xv_minus_mean = xv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(xmean, height,
+ width) # [batch_size, 68, 64, 64]
+ ymean = means[:, :, 1]
+ yv_minus_mean = yv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(ymean, height,
+ width) # [batch_size, 68, 64, 64]
+ wt_xv_minus_mean = xv_minus_mean
+ wt_yv_minus_mean = yv_minus_mean
+
+ wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096]
+ wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096]
+ wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096]
+ wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096]
+ vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1) # [batch_size*68, 2, 4096]
+
+ htp_vec = htp.view(batch_size * num_points, 1, height * width)
+ htp_vec = htp_vec.expand(-1, 2, -1)
+
+ covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2)) # [batch_size*68, 2, 2]
+ covariance = covariance.view(batch_size, num_points, num_dim_image, num_dim_image) # [batch_size, 68, 2, 2]
+
+ V_1 = htp.sum([2, 3]) + EPSILON # [batch_size, 68]
+ V_2 = torch.pow(htp, 2).sum([2, 3]) + EPSILON # [batch_size, 68]
+
+ denominator = V_1 - (V_2 / V_1)
+ covariance = covariance / expand_two_dimensions_at_end(denominator, num_dim_image, num_dim_image)
+
+ return covariance
+
+ def forward(self, heatmap, groudtruth):
+
+ batch, npoints, h, w = heatmap.shape
+
+ heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
+ heatmap = heatmap / heatmap_sum.view(batch, npoints, 1, 1)
+
+ # means [batch_size, 68, 2]
+ means = self.weighted_mean(heatmap)
+
+ # covars [batch_size, 68, 2, 2]
+ covars = self.unbiased_weighted_covariance(heatmap, means)
+
+ # eigenvalues [batch_size * 68, 2] , eigenvectors [batch_size * 68, 2, 2]
+ covars = covars.view(batch * npoints, 2, 2).cpu()
+ evalues, evectors = covars.symeig(eigenvectors=True)
+ evalues = evalues.view(batch, npoints, 2)
+ evectors = evectors.view(batch, npoints, 2, 2)
+ means = means.cpu()
+
+ results = [dict() for _ in range(batch)]
+ for i in range(batch):
+ results[i]['pred'] = means[i].numpy().tolist()
+ results[i]['gt'] = groudtruth[i].cpu().numpy().tolist()
+ results[i]['evalues'] = evalues[i].numpy().tolist()
+ results[i]['evectors'] = evectors[i].numpy().tolist()
+
+ return results
diff --git a/external/vgghead_detector/VGGDetector.py b/external/vgghead_detector/VGGDetector.py
new file mode 100644
index 0000000000000000000000000000000000000000..b67ad6f7e5fc8ff6f7735ce55f949a1b6df79287
--- /dev/null
+++ b/external/vgghead_detector/VGGDetector.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+# Copyright (c) Xuangeng Chu (xg.chu@outlook.com)
+# Modified based on code from Orest Kupyn (University of Oxford).
+
+import os
+import torch
+import numpy as np
+import torchvision
+
+from .utils_vgghead import nms
+from .utils_lmks_detector import LmksDetector
+
+class VGGHeadDetector(torch.nn.Module):
+ def __init__(self, device,
+ vggheadmodel_path=None):
+ super().__init__()
+ self.image_size = 640
+ self._device = device
+ self.vggheadmodel_path = vggheadmodel_path
+ self._init_models()
+
+ def _init_models(self,):
+ # vgg_heads_l
+ self.model = torch.load(self.vggheadmodel_path, map_location='cpu')
+ self.model.to(self._device).eval()
+
+ @torch.no_grad()
+ def forward(self, image_tensor, image_key, conf_threshold=0.5):
+ if not hasattr(self, 'model'):
+ self._init_models()
+ image_tensor = image_tensor.to(self._device).float()
+ image, padding, scale = self._preprocess(image_tensor)
+ bbox, scores, flame_params = self.model(image)
+ bbox, vgg_results = self._postprocess(bbox, scores, flame_params, conf_threshold)
+
+ if bbox is None:
+ print('VGGHeadDetector: No face detected: {}!'.format(image_key))
+ return None, None, None
+ vgg_results['normalize'] = {'padding': padding, 'scale': scale}
+
+ # bbox
+ bbox = bbox.clip(0, self.image_size)
+ bbox[[0, 2]] -= padding[0]; bbox[[1, 3]] -= padding[1]; bbox /= scale
+ bbox = bbox.clip(0, self.image_size / scale)
+
+ return vgg_results, bbox, None
+
+ def _preprocess(self, image):
+ _, h, w = image.shape
+ if h > w:
+ new_h, new_w = self.image_size, int(w * self.image_size / h)
+ else:
+ new_h, new_w = int(h * self.image_size / w), self.image_size
+ scale = self.image_size / max(h, w)
+ image = torchvision.transforms.functional.resize(image, (new_h, new_w), antialias=True)
+ pad_w = self.image_size - image.shape[2]
+ pad_h = self.image_size - image.shape[1]
+ image = torchvision.transforms.functional.pad(image, (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2), fill=127)
+ image = image.unsqueeze(0).float() / 255.0
+ return image, np.array([pad_w // 2, pad_h // 2]), scale
+
+ def _postprocess(self, bbox, scores, flame_params, conf_threshold):
+ # flame_params = {"shape": 300, "exp": 100, "rotation": 6, "jaw": 3, "translation": 3, "scale": 1}
+ bbox, scores, flame_params = nms(bbox, scores, flame_params, confidence_threshold=conf_threshold)
+ if bbox.shape[0] == 0:
+ return None, None
+ max_idx = ((bbox[:, 3] - bbox[:, 1]) * (bbox[:, 2] - bbox[:, 0])).argmax().long()
+ bbox, flame_params = bbox[max_idx], flame_params[max_idx]
+ if bbox[0] < 5 and bbox[1] < 5 and bbox[2] > 635 and bbox[3] > 635:
+ return None, None
+ # flame
+ posecode = torch.cat([flame_params.new_zeros(3), flame_params[400:403]])
+ vgg_results = {
+ 'rotation_6d': flame_params[403:409], 'translation': flame_params[409:412], 'scale': flame_params[412:],
+ 'shapecode': flame_params[:300], 'expcode': flame_params[300:400], 'posecode': posecode,
+ }
+ return bbox, vgg_results
diff --git a/external/vgghead_detector/__init__.py b/external/vgghead_detector/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8b26f25d4bfc0b67c20571ae52e7db9510a54ba
--- /dev/null
+++ b/external/vgghead_detector/__init__.py
@@ -0,0 +1,5 @@
+#!/usr/bin/env python
+# Copyright (c) Xuangeng Chu (xg.chu@outlook.com)
+
+from .VGGDetector import VGGHeadDetector
+from .utils_vgghead import reproject_vertices
diff --git a/external/vgghead_detector/utils_lmks_detector.py b/external/vgghead_detector/utils_lmks_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..71068712cf12ba8d1b6e90e9389b0c4eb5c7b100
--- /dev/null
+++ b/external/vgghead_detector/utils_lmks_detector.py
@@ -0,0 +1,574 @@
+#################################################
+# written by wangduomin@xiaobing.ai #
+# modified by xg.chu@outlook.com #
+#################################################
+import os
+import torch
+import numpy as np
+import torchvision
+os.environ["GLOG_minloglevel"] ="2"
+
+class LmksDetector(torch.nn.Module):
+ def __init__(self, device, model_path):
+ super().__init__()
+ self.size = 256
+ self._device = device
+ # model
+ model = LandmarkDetector(model_path)
+ self.model = model.to(self._device).eval()
+
+ def _transform(self, image, bbox):
+ assert bbox[3]-bbox[1] == bbox[2]-bbox[0], 'Bounding box should be square.'
+ c_image = torchvision.transforms.functional.crop(image, bbox[1], bbox[0], bbox[3]-bbox[1], bbox[2]-bbox[0])
+ c_image = torchvision.transforms.functional.resize(c_image, (self.size, self.size), antialias=True)
+ c_image = torchvision.transforms.functional.normalize(c_image/255.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ return c_image[None], self.size / (bbox[3]-bbox[1])
+
+ @torch.no_grad()
+ def forward(self, image, bbox):
+ assert image.dim() == 3, 'Input must be a 3D tensor.'
+ if image.max() < 2.0:
+ print('Image Should be in 0-255 range, but found in 0-1 range.')
+ bbox = expand_bbox(bbox, ratio=1.38)
+ # image_bbox = torchvision.utils.draw_bounding_boxes(image.cpu().to(torch.uint8), bbox[None], width=3, colors='green')
+ # torchvision.utils.save_image(image_bbox/255.0, 'image_bbox.jpg')
+ c_image, scale = self._transform(image.to(self._device), bbox)
+ landmarks = self.model(c_image).squeeze(0) / scale
+ landmarks = landmarks + bbox[:2][None]
+ landmarks = mapping_lmk98_to_lmk70(landmarks)
+ return landmarks
+
+
+def mapping_lmk98_to_lmk70(lmk98):
+ lmk70 = lmk98[[
+ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32,
+ 33, 34, 35, 36, 37, 42, 43, 44, 45, 46,
+ 51, 52, 53, 54, 55, 56, 57, 58, 59,
+ 60, 61, 63, 64, 65, 67,
+ 68, 69, 71, 72, 73, 75,
+ 76, 77, 78, 79, 80, 81, 82, 83, 84, 85,
+ 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97
+ ]]
+ return lmk70
+
+
+def expand_bbox(bbox, ratio=1.0):
+ xmin, ymin, xmax, ymax = bbox
+ cenx, ceny = ((xmin + xmax) / 2).long(), ((ymin + ymax) / 2).long()
+ extend_size = torch.sqrt((ymax - ymin + 1) * (xmax - xmin + 1)) * ratio
+ xmine, xmaxe = cenx - extend_size // 2, cenx + extend_size // 2
+ ymine, ymaxe = ceny - extend_size // 2, ceny + extend_size // 2
+ return torch.stack([xmine, ymine, xmaxe, ymaxe]).long()
+
+
+# ------------------------------------------------------------------------------
+# Reference: https://github.com/HRNet/HRNet-Image-Classification
+# ------------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.model_zoo as model_zoo
+
+__all__ = [ 'hrnet18s', 'hrnet18', 'hrnet32' ]
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes, )
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, )
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class HighResolutionModule(nn.Module):
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+ num_channels, fuse_method, multi_scale_output=True):
+ super(HighResolutionModule, self).__init__()
+ self._check_branches(
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+ self.num_inchannels = num_inchannels
+ self.fuse_method = fuse_method
+ self.num_branches = num_branches
+
+ self.multi_scale_output = multi_scale_output
+
+ self.branches = self._make_branches(
+ num_branches, blocks, num_blocks, num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(False)
+
+ def _check_branches(self, num_branches, blocks, num_blocks,
+ num_inchannels, num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+ num_branches, len(num_blocks))
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+ num_branches, len(num_channels))
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_inchannels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+ num_branches, len(num_inchannels))
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+ stride=1):
+ downsample = None
+ if stride != 1 or \
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.num_inchannels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(num_channels[branch_index] * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index], stride, downsample))
+ self.num_inchannels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index]))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ num_inchannels = self.num_inchannels
+ fuse_layers = []
+ for i in range(num_branches if self.multi_scale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_inchannels[i],
+ 1,
+ 1,
+ 0,
+ bias=False),
+ nn.BatchNorm2d(num_inchannels[i]),
+ nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv3x3s = []
+ for k in range(i-j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False),
+ nn.BatchNorm2d(num_outchannels_conv3x3)))
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False),
+ nn.BatchNorm2d(num_outchannels_conv3x3),
+ nn.ReLU(False)))
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self):
+ return self.num_inchannels
+
+ def forward(self, x):
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+ for j in range(1, self.num_branches):
+ if i == j:
+ y = y + x[j]
+ else:
+ y = y + self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+
+ return x_fuse
+
+class HighResolutionNet(nn.Module):
+
+ def __init__(self, num_modules, num_branches, block,
+ num_blocks, num_channels, fuse_method, **kwargs):
+ super(HighResolutionNet, self).__init__()
+ self.num_modules = num_modules
+ self.num_branches = num_branches
+ self.block = block
+ self.num_blocks = num_blocks
+ self.num_channels = num_channels
+ self.fuse_method = fuse_method
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ # layer1
+ num_channels, num_blocks = self.num_channels[0][0], self.num_blocks[0][0]
+ self.layer1 = self._make_layer(self.block[0], 64, num_channels, num_blocks)
+ stage1_out_channel = self.block[0].expansion*num_channels
+ # layer2
+ num_channels, num_blocks = self.num_channels[1], self.num_blocks[1]
+ num_channels = [
+ num_channels[i] * self.block[1].expansion for i in range(len(num_channels))]
+ self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(1, num_channels)
+ # layer3
+ num_channels, num_blocks = self.num_channels[2], self.num_blocks[2]
+ num_channels = [
+ num_channels[i] * self.block[2].expansion for i in range(len(num_channels))]
+ self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(2, num_channels)
+ # layer4
+ num_channels, num_blocks = self.num_channels[3], self.num_blocks[3]
+ num_channels = [
+ num_channels[i] * self.block[3].expansion for i in range(len(num_channels))]
+ self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(3, num_channels, multi_scale_output=True)
+ self._out_channels = sum(pre_stage_channels)
+
+ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(nn.Sequential(
+ nn.Conv2d(num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ 3,
+ 1,
+ 1,
+ bias=False),
+ nn.BatchNorm2d(
+ num_channels_cur_layer[i], ),
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv3x3s = []
+ for j in range(i+1-num_branches_pre):
+ inchannels = num_channels_pre_layer[-1]
+ outchannels = num_channels_cur_layer[i] \
+ if j == i-num_branches_pre else inchannels
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(
+ inchannels, outchannels, 3, 2, 1, bias=False),
+ nn.BatchNorm2d(outchannels, ),
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv3x3s))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion, ),
+ )
+
+ layers = []
+ layers.append(block(inplanes, planes, stride, downsample))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, stage_index, in_channels,
+ multi_scale_output=True):
+ num_modules = self.num_modules[stage_index]
+ num_branches = self.num_branches[stage_index]
+ num_blocks = self.num_blocks[stage_index]
+ num_channels = self.num_channels[stage_index]
+ block = self.block[stage_index]
+ fuse_method = self.fuse_method[stage_index]
+ modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used last module
+ if not multi_scale_output and i == num_modules - 1:
+ reset_multi_scale_output = False
+ else:
+ reset_multi_scale_output = True
+
+ modules.append(
+ HighResolutionModule(num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ fuse_method,
+ reset_multi_scale_output)
+ )
+ in_channels = modules[-1].get_num_inchannels()
+
+ return nn.Sequential(*modules), in_channels
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.num_branches[1]):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.num_branches[2]):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.num_branches[3]):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ kwargs = {
+ 'size': tuple(y_list[0].shape[-2:]),
+ 'mode': 'bilinear', 'align_corners': False,
+ }
+ return torch.cat([F.interpolate(y,**kwargs) for y in y_list], 1)
+
+def hrnet18s(pretrained=True, **kwargs):
+ model = HighResolutionNet(
+ num_modules = [1, 1, 3, 2],
+ num_branches = [1, 2, 3, 4],
+ block = [Bottleneck, BasicBlock, BasicBlock, BasicBlock],
+ num_blocks = [(2,), (2,2), (2,2,2), (2,2,2,2)],
+ num_channels = [(64,), (18,36), (18,36,72), (18,36,72,144)],
+ fuse_method = ['SUM', 'SUM', 'SUM', 'SUM'],
+ **kwargs
+ )
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['hrnet_w18s']), strict=False)
+ return model
+
+def hrnet18(pretrained=False, **kwargs):
+ model = HighResolutionNet(
+ num_modules = [1, 1, 4, 3],
+ num_branches = [1, 2, 3, 4],
+ block = [Bottleneck, BasicBlock, BasicBlock, BasicBlock],
+ num_blocks = [(4,), (4,4), (4,4,4), (4,4,4,4)],
+ num_channels = [(64,), (18,36), (18,36,72), (18,36,72,144)],
+ fuse_method = ['SUM', 'SUM', 'SUM', 'SUM'],
+ **kwargs
+ )
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['hrnet18']), strict=False)
+ return model
+
+def hrnet32(pretrained=False, **kwargs):
+ model = HighResolutionNet(
+ num_modules = [1, 1, 4, 3],
+ num_branches = [1, 2, 3, 4],
+ block = [Bottleneck, BasicBlock, BasicBlock, BasicBlock],
+ num_blocks = [(4,), (4,4), (4,4,4), (4,4,4,4)],
+ num_channels = [(64,), (32,64), (32,64,128), (32,64,128,256)],
+ fuse_method = ['SUM', 'SUM', 'SUM', 'SUM'],
+ **kwargs
+ )
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['hrnet32']), strict=False)
+ return model
+
+
+class BinaryHeadBlock(nn.Module):
+ """BinaryHeadBlock
+ """
+ def __init__(self, in_channels, proj_channels, out_channels, **kwargs):
+ super(BinaryHeadBlock, self).__init__()
+ self.layers = nn.Sequential(
+ nn.Conv2d(in_channels, proj_channels, 1, bias=False),
+ nn.BatchNorm2d(proj_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(proj_channels, out_channels*2, 1, bias=False),
+ )
+
+ def forward(self, input):
+ N, C, H, W = input.shape
+ return self.layers(input).view(N, 2, -1, H, W)
+
+def heatmap2coord(heatmap, topk=9):
+ N, C, H, W = heatmap.shape
+ score, index = heatmap.view(N,C,1,-1).topk(topk, dim=-1)
+ coord = torch.cat([index%W, index//W], dim=2)
+ return (coord*F.softmax(score, dim=-1)).sum(-1)
+
+class BinaryHeatmap2Coordinate(nn.Module):
+ """BinaryHeatmap2Coordinate
+ """
+ def __init__(self, stride=4.0, topk=5, **kwargs):
+ super(BinaryHeatmap2Coordinate, self).__init__()
+ self.topk = topk
+ self.stride = stride
+
+ def forward(self, input):
+ return self.stride * heatmap2coord(input[:,1,...], self.topk)
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ format_string += 'topk={}, '.format(self.topk)
+ format_string += 'stride={}'.format(self.stride)
+ format_string += ')'
+ return format_string
+
+class HeatmapHead(nn.Module):
+ """HeatmapHead
+ """
+ def __init__(self):
+ super(HeatmapHead, self).__init__()
+ self.decoder = BinaryHeatmap2Coordinate(
+ topk=9,
+ stride=4.0,
+ )
+ self.head = BinaryHeadBlock(
+ in_channels=270,
+ proj_channels=270,
+ out_channels=98,
+ )
+
+ def forward(self, input):
+ heatmap = self.head(input)
+ ldmk = self.decoder(heatmap)
+ return heatmap[:,1,...], ldmk
+
+
+class LandmarkDetector(nn.Module):
+ def __init__(self, model_path):
+ super(LandmarkDetector, self).__init__()
+
+ self.backbone = HighResolutionNet(
+ num_modules = [1, 1, 4, 3],
+ num_branches = [1, 2, 3, 4],
+ block = [Bottleneck, BasicBlock, BasicBlock, BasicBlock],
+ num_blocks = [(4,), (4,4), (4,4,4), (4,4,4,4)],
+ num_channels = [(64,), (18,36), (18,36,72), (18,36,72,144)],
+ fuse_method = ['SUM', 'SUM', 'SUM', 'SUM']
+ )
+
+ self.heatmap_head = HeatmapHead()
+
+ self.load_state_dict(torch.load(model_path, map_location='cpu'))
+
+ def forward(self, img):
+ heatmap, landmark = self.heatmap_head(self.backbone(img))
+
+ return landmark
diff --git a/external/vgghead_detector/utils_vgghead.py b/external/vgghead_detector/utils_vgghead.py
new file mode 100644
index 0000000000000000000000000000000000000000..5db2bb3e803bc260479d4ed021dc28575fe8cfcf
--- /dev/null
+++ b/external/vgghead_detector/utils_vgghead.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python
+# Copyright (c) Xuangeng Chu (xg.chu@outlook.com)
+# Modified based on code from Orest Kupyn (University of Oxford).
+
+import torch
+import torchvision
+
+def reproject_vertices(flame_model, vgg_results):
+ # flame_model = FLAMEModel(n_shape=300, n_exp=100, scale=1.0)
+ vertices, _ = flame_model(
+ shape_params=vgg_results['shapecode'],
+ expression_params=vgg_results['expcode'],
+ pose_params=vgg_results['posecode'],
+ verts_sclae=1.0
+ )
+ vertices[:, :, 2] += 0.05 # MESH_OFFSET_Z
+ vgg_landmarks3d = flame_model._vertices2landmarks(vertices)
+ vgg_transform_results = vgg_results['transform']
+ rotation_mat = rot_mat_from_6dof(vgg_transform_results['rotation_6d']).type(vertices.dtype)
+ translation = vgg_transform_results['translation'][:, None, :]
+ scale = torch.clamp(vgg_transform_results['scale'][:, None], 1e-8)
+ rot_vertices = vertices.clone()
+ rot_vertices = torch.matmul(rotation_mat.unsqueeze(1), rot_vertices.unsqueeze(-1))[..., 0]
+ vgg_landmarks3d = torch.matmul(rotation_mat.unsqueeze(1), vgg_landmarks3d.unsqueeze(-1))[..., 0]
+ proj_vertices = (rot_vertices * scale) + translation
+ vgg_landmarks3d = (vgg_landmarks3d * scale) + translation
+
+ trans_padding, trans_scale = vgg_results['normalize']['padding'], vgg_results['normalize']['scale']
+ proj_vertices[:, :, 0] -= trans_padding[:, 0, None]
+ proj_vertices[:, :, 1] -= trans_padding[:, 1, None]
+ proj_vertices = proj_vertices / trans_scale[:, None, None]
+ vgg_landmarks3d[:, :, 0] -= trans_padding[:, 0, None]
+ vgg_landmarks3d[:, :, 1] -= trans_padding[:, 1, None]
+ vgg_landmarks3d = vgg_landmarks3d / trans_scale[:, None, None]
+ return proj_vertices.float()[..., :2], vgg_landmarks3d.float()[..., :2]
+
+
+def rot_mat_from_6dof(v: torch.Tensor) -> torch.Tensor:
+ assert v.shape[-1] == 6
+ v = v.view(-1, 6)
+ vx, vy = v[..., :3].clone(), v[..., 3:].clone()
+
+ b1 = torch.nn.functional.normalize(vx, dim=-1)
+ b3 = torch.nn.functional.normalize(torch.cross(b1, vy, dim=-1), dim=-1)
+ b2 = -torch.cross(b1, b3, dim=1)
+ return torch.stack((b1, b2, b3), dim=-1)
+
+
+def nms(boxes_xyxy, scores, flame_params,
+ confidence_threshold: float = 0.5, iou_threshold: float = 0.5,
+ top_k: int = 1000, keep_top_k: int = 100
+ ):
+ for pred_bboxes_xyxy, pred_bboxes_conf, pred_flame_params in zip(
+ boxes_xyxy.detach().float(),
+ scores.detach().float(),
+ flame_params.detach().float(),
+ ):
+ pred_bboxes_conf = pred_bboxes_conf.squeeze(-1) # [Anchors]
+ conf_mask = pred_bboxes_conf >= confidence_threshold
+
+ pred_bboxes_conf = pred_bboxes_conf[conf_mask]
+ pred_bboxes_xyxy = pred_bboxes_xyxy[conf_mask]
+ pred_flame_params = pred_flame_params[conf_mask]
+
+ # Filter all predictions by self.nms_top_k
+ if pred_bboxes_conf.size(0) > top_k:
+ topk_candidates = torch.topk(pred_bboxes_conf, k=top_k, largest=True, sorted=True)
+ pred_bboxes_conf = pred_bboxes_conf[topk_candidates.indices]
+ pred_bboxes_xyxy = pred_bboxes_xyxy[topk_candidates.indices]
+ pred_flame_params = pred_flame_params[topk_candidates.indices]
+
+ # NMS
+ idx_to_keep = torchvision.ops.boxes.nms(boxes=pred_bboxes_xyxy, scores=pred_bboxes_conf, iou_threshold=iou_threshold)
+
+ final_bboxes = pred_bboxes_xyxy[idx_to_keep][: keep_top_k] # [Instances, 4]
+ final_scores = pred_bboxes_conf[idx_to_keep][: keep_top_k] # [Instances, 1]
+ final_params = pred_flame_params[idx_to_keep][: keep_top_k] # [Instances, Flame Params]
+ return final_bboxes, final_scores, final_params
diff --git a/flame_tracking_single_image.py b/flame_tracking_single_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aad348af8c8207669ed8d3b65968d4af7abfa04
--- /dev/null
+++ b/flame_tracking_single_image.py
@@ -0,0 +1,345 @@
+import argparse
+import json
+import os
+import time
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+import torchvision
+import tyro
+import yaml
+from loguru import logger
+from PIL import Image
+
+from external.human_matting import StyleMatteEngine as HumanMattingEngine
+from external.landmark_detection.FaceBoxesV2.faceboxes_detector import \
+ FaceBoxesDetector
+from external.landmark_detection.infer_image import Alignment
+from external.vgghead_detector import VGGHeadDetector
+from vhap.config.base import BaseTrackingConfig
+from vhap.export_as_nerf_dataset import (NeRFDatasetWriter,
+ TrackedFLAMEDatasetWriter, split_json)
+from vhap.model.tracker import GlobalTracker
+
+# Define error codes for various processing failures.
+ERROR_CODE = {'FailedToDetect': 1, 'FailedToOptimize': 2, 'FailedToExport': 3}
+
+
+def expand_bbox(bbox, scale=1.1):
+ """Expands the bounding box by a given scale."""
+ xmin, ymin, xmax, ymax = bbox.unbind(dim=-1)
+ center_x, center_y = (xmin + xmax) / 2, (ymin + ymax) / 2
+ extension_size = torch.sqrt((ymax - ymin) * (xmax - xmin)) * scale
+ x_min_expanded = center_x - extension_size / 2
+ x_max_expanded = center_x + extension_size / 2
+ y_min_expanded = center_y - extension_size / 2
+ y_max_expanded = center_y + extension_size / 2
+ return torch.stack(
+ [x_min_expanded, y_min_expanded, x_max_expanded, y_max_expanded],
+ dim=-1)
+
+
+def load_config(src_folder: Path):
+ """Load configuration from the given source folder."""
+ config_file_path = src_folder / 'config.yml'
+ if not config_file_path.exists():
+ src_folder = sorted(
+ src_folder.iterdir())[-1] # Get the last modified folder
+ config_file_path = src_folder / 'config.yml'
+ assert config_file_path.exists(), f'File not found: {config_file_path}'
+
+ config_data = yaml.load(config_file_path.read_text(), Loader=yaml.Loader)
+ return src_folder, config_data
+
+
+class FlameTrackingSingleImage:
+ """Class for tracking and processing a single image."""
+ def __init__(
+ self,
+ output_dir,
+ alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
+ vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
+ human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
+ facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
+ detect_iris_landmarks=False):
+
+ logger.info(f'Output Directory: {output_dir}')
+
+ start_time = time.time()
+ logger.info('Loading Pre-trained Models...')
+
+ self.output_dir = output_dir
+ self.output_preprocess = os.path.join(output_dir, 'preprocess')
+ self.output_tracking = os.path.join(output_dir, 'tracking')
+ self.output_export = os.path.join(output_dir, 'export')
+ self.device = 'cuda:0'
+
+ # Load alignment model
+ assert os.path.exists(
+ alignment_model_path), f'{alignment_model_path} does not exist!'
+ args = self._parse_args()
+ args.model_path = alignment_model_path
+ self.alignment = Alignment(args,
+ alignment_model_path,
+ dl_framework='pytorch',
+ device_ids=[0])
+
+ # Load VGG head model
+ assert os.path.exists(
+ vgghead_model_path), f'{vgghead_model_path} does not exist!'
+ self.vgghead_encoder = VGGHeadDetector(
+ device=self.device, vggheadmodel_path=vgghead_model_path)
+
+ # Load human matting model
+ assert os.path.exists(
+ human_matting_path), f'{human_matting_path} does not exist!'
+ self.matting_engine = HumanMattingEngine(
+ device=self.device, human_matting_path=human_matting_path)
+
+ # Load face box detector model
+ assert os.path.exists(
+ facebox_model_path), f'{facebox_model_path} does not exist!'
+ self.detector = FaceBoxesDetector('FaceBoxes', facebox_model_path,
+ True, self.device)
+
+ self.detect_iris_landmarks_flag = detect_iris_landmarks
+ if self.detect_iris_landmarks_flag:
+ from fdlite import FaceDetection, FaceLandmark, IrisLandmark
+ self.iris_detect_faces = FaceDetection()
+ self.iris_detect_face_landmarks = FaceLandmark()
+ self.iris_detect_iris_landmarks = IrisLandmark()
+
+ end_time = time.time()
+ torch.cuda.empty_cache()
+ logger.info(f'Finished Loading Pre-trained Models. Time: '
+ f'{end_time - start_time:.2f}s')
+
+ def _parse_args(self):
+ parser = argparse.ArgumentParser(description='Evaluation script')
+ parser.add_argument('--output_dir',
+ type=str,
+ help='Output directory',
+ default='output')
+ parser.add_argument('--config_name',
+ type=str,
+ help='Configuration name',
+ default='alignment')
+ return parser.parse_args()
+
+ def preprocess(self, input_image_path):
+ """Preprocess the input image for tracking."""
+ if not os.path.exists(input_image_path):
+ logger.warning(f'{input_image_path} does not exist!')
+ return ERROR_CODE['FailedToDetect']
+
+ start_time = time.time()
+ logger.info('Starting Preprocessing...')
+ name_list = []
+ frame_index = 0
+
+ # Bounding box detection
+ frame = torchvision.io.read_image(input_image_path)
+ try:
+ _, frame_bbox, _ = self.vgghead_encoder(frame, frame_index)
+ except Exception:
+ logger.error('Failed to detect face')
+ return ERROR_CODE['FailedToDetect']
+
+ if frame_bbox is None:
+ logger.error('Failed to detect face')
+ return ERROR_CODE['FailedToDetect']
+
+ # Expand bounding box
+ name_list.append('00000.png')
+ frame_bbox = expand_bbox(frame_bbox, scale=1.65).long()
+
+ # Crop and resize
+ cropped_frame = torchvision.transforms.functional.crop(
+ frame,
+ top=frame_bbox[1],
+ left=frame_bbox[0],
+ height=frame_bbox[3] - frame_bbox[1],
+ width=frame_bbox[2] - frame_bbox[0])
+ cropped_frame = torchvision.transforms.functional.resize(
+ cropped_frame, (1024, 1024), antialias=True)
+
+ # Apply matting
+ cropped_frame, mask = self.matting_engine(cropped_frame / 255.0,
+ return_type='matting',
+ background_rgb=1.0)
+ cropped_frame = cropped_frame.cpu() * 255.0
+ saved_image = np.round(cropped_frame.cpu().permute(
+ 1, 2, 0).numpy()).astype(np.uint8)[:, :, (2, 1, 0)]
+
+ # Create output directories if not exist
+ self.sub_output_dir = os.path.join(
+ self.output_preprocess,
+ os.path.splitext(os.path.basename(input_image_path))[0])
+ output_image_dir = os.path.join(self.sub_output_dir, 'images')
+ output_mask_dir = os.path.join(self.sub_output_dir, 'mask')
+ output_alpha_map_dir = os.path.join(self.sub_output_dir, 'alpha_maps')
+
+ os.makedirs(output_image_dir, exist_ok=True)
+ os.makedirs(output_mask_dir, exist_ok=True)
+ os.makedirs(output_alpha_map_dir, exist_ok=True)
+
+ # Save processed image, mask and alpha map
+ cv2.imwrite(os.path.join(output_image_dir, name_list[frame_index]),
+ saved_image)
+ cv2.imwrite(os.path.join(output_mask_dir, name_list[frame_index]),
+ np.array((mask.cpu() * 255.0)).astype(np.uint8))
+ cv2.imwrite(
+ os.path.join(output_alpha_map_dir,
+ name_list[frame_index]).replace('.png', '.jpg'),
+ (np.ones_like(saved_image) * 255).astype(np.uint8))
+
+ # Landmark detection
+ detections, _ = self.detector.detect(saved_image, 0.8, 1)
+ for idx, detection in enumerate(detections):
+ x1_ori, y1_ori = detection[2], detection[3]
+ x2_ori, y2_ori = x1_ori + detection[4], y1_ori + detection[5]
+
+ scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
+ center_w, center_h = (x1_ori + x2_ori) / 2, (y1_ori + y2_ori) / 2
+ scale, center_w, center_h = float(scale), float(center_w), float(
+ center_h)
+
+ face_landmarks = self.alignment.analyze(saved_image, scale,
+ center_w, center_h)
+
+ # Normalize and save landmarks
+ normalized_landmarks = np.zeros((face_landmarks.shape[0], 3))
+ normalized_landmarks[:, :2] = face_landmarks / 1024
+
+ landmark_output_dir = os.path.join(self.sub_output_dir, 'landmark2d')
+ os.makedirs(landmark_output_dir, exist_ok=True)
+
+ landmark_data = {
+ 'bounding_box': [],
+ 'face_landmark_2d': normalized_landmarks[None, ...],
+ }
+
+ landmark_path = os.path.join(landmark_output_dir, 'landmarks.npz')
+ np.savez(landmark_path, **landmark_data)
+
+ if self.detect_iris_landmarks_flag:
+ self._detect_iris_landmarks(
+ os.path.join(output_image_dir, name_list[frame_index]))
+
+ end_time = time.time()
+ torch.cuda.empty_cache()
+ logger.info(
+ f'Finished Processing Image. Time: {end_time - start_time:.2f}s')
+
+ return 0
+
+ def optimize(self):
+ """Optimize the tracking model using configuration data."""
+ start_time = time.time()
+ logger.info('Starting Optimization...')
+
+ tyro.extras.set_accent_color('bright_yellow')
+ config_data = tyro.cli(BaseTrackingConfig)
+
+ config_data.data.sequence = self.sub_output_dir.split('/')[-1]
+ config_data.data.root_folder = Path(
+ os.path.dirname(self.sub_output_dir))
+
+ if not os.path.exists(self.sub_output_dir):
+ logger.error(f'Failed to load {self.sub_output_dir}')
+ return ERROR_CODE['FailedToOptimize']
+
+ config_data.exp.output_folder = Path(self.output_tracking)
+ tracker = GlobalTracker(config_data)
+ tracker.optimize()
+
+ end_time = time.time()
+ torch.cuda.empty_cache()
+ logger.info(
+ f'Finished Optimization. Time: {end_time - start_time:.2f}s')
+
+ return 0
+
+ def _detect_iris_landmarks(self, image_path):
+ """Detect iris landmarks in the given image."""
+ from fdlite import face_detection_to_roi, iris_roi_from_face_landmarks
+
+ img = Image.open(image_path)
+ img_size = (1024, 1024)
+
+ face_detections = self.iris_detect_faces(img)
+ if len(face_detections) != 1:
+ logger.warning('Empty iris landmarks')
+ else:
+ face_detection = face_detections[0]
+ try:
+ face_roi = face_detection_to_roi(face_detection, img_size)
+ except ValueError:
+ logger.warning('Empty iris landmarks')
+ return
+
+ face_landmarks = self.iris_detect_face_landmarks(img, face_roi)
+ if len(face_landmarks) == 0:
+ logger.warning('Empty iris landmarks')
+ return
+
+ iris_rois = iris_roi_from_face_landmarks(face_landmarks, img_size)
+
+ if len(iris_rois) != 2:
+ logger.warning('Empty iris landmarks')
+ return
+
+ landmarks = []
+ for iris_roi in iris_rois[::-1]:
+ try:
+ iris_landmarks = self.iris_detect_iris_landmarks(
+ img, iris_roi).iris[0:1]
+ except np.linalg.LinAlgError:
+ logger.warning('Failed to get iris landmarks')
+ break
+
+ # For each landmark, append x and y coordinates scaled to 1024.
+ for landmark in iris_landmarks:
+ landmarks.append(landmark.x * 1024)
+ landmarks.append(landmark.y * 1024)
+
+ landmark_data = {'00000.png': landmarks}
+ json.dump(
+ landmark_data,
+ open(
+ os.path.join(self.sub_output_dir, 'landmark2d',
+ 'iris.json'), 'w'))
+
+ def export(self):
+ """Export the tracking results to configured folder."""
+ logger.info(f'Beginning export from {self.output_tracking}')
+ start_time = time.time()
+ if not os.path.exists(self.output_tracking):
+ logger.error(f'Failed to load {self.output_tracking}')
+ return ERROR_CODE['FailedToExport'], 'Failed'
+
+ src_folder = Path(self.output_tracking)
+ tgt_folder = Path(self.output_export,
+ self.sub_output_dir.split('/')[-1])
+ src_folder, config_data = load_config(src_folder)
+
+ nerf_writer = NeRFDatasetWriter(config_data.data, tgt_folder, None,
+ None, 'white')
+ nerf_writer.write()
+
+ flame_writer = TrackedFLAMEDatasetWriter(config_data.model,
+ src_folder,
+ tgt_folder,
+ mode='param',
+ epoch=-1)
+ flame_writer.write()
+
+ split_json(tgt_folder)
+
+ end_time = time.time()
+ torch.cuda.empty_cache()
+ logger.info(f'Finished Export. Time: {end_time - start_time:.2f}s')
+
+ return 0, str(tgt_folder)
diff --git a/lam/__init__.py b/lam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/lam/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/lam/datasets/__init__.py b/lam/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..323127c7d93f0a57f90cc8649ee2a67b6b630762
--- /dev/null
+++ b/lam/datasets/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .mixer import MixerDataset
diff --git a/lam/datasets/base.py b/lam/datasets/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e300d8517e9617cfa6d8be41fc134d35306924e9
--- /dev/null
+++ b/lam/datasets/base.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+import traceback
+import json
+import numpy as np
+import torch
+from PIL import Image
+from typing import Optional, Union
+from megfile import smart_open, smart_path_join, smart_exists
+
+
+class BaseDataset(torch.utils.data.Dataset, ABC):
+ def __init__(self, root_dirs: str, meta_path: Optional[Union[list, str]]):
+ super().__init__()
+ self.root_dirs = root_dirs
+ self.uids = self._load_uids(meta_path)
+
+ def __len__(self):
+ return len(self.uids)
+
+ @abstractmethod
+ def inner_get_item(self, idx):
+ pass
+
+ def __getitem__(self, idx):
+ try:
+ return self.inner_get_item(idx)
+ except Exception as e:
+ traceback.print_exc()
+ print(f"[DEBUG-DATASET] Error when loading {self.uids[idx]}")
+ # raise e
+ return self.__getitem__((idx + 1) % self.__len__())
+
+ @staticmethod
+ def _load_uids(meta_path: Optional[Union[list, str]]):
+ # meta_path is a json file
+ if isinstance(meta_path, str):
+ with open(meta_path, 'r') as f:
+ uids = json.load(f)
+ else:
+ uids_lst = []
+ max_total = 0
+ for pth, weight in meta_path:
+ with open(pth, 'r') as f:
+ uids = json.load(f)
+ max_total = max(len(uids) / weight, max_total)
+ uids_lst.append([uids, weight, pth])
+ merged_uids = []
+ for uids, weight, pth in uids_lst:
+ repeat = 1
+ if len(uids) < int(weight * max_total):
+ repeat = int(weight * max_total) // len(uids)
+ cur_uids = uids * repeat
+ merged_uids += cur_uids
+ print("Data Path:", pth, "Repeat:", repeat, "Final Length:", len(cur_uids))
+ uids = merged_uids
+ print("Total UIDs:", len(uids))
+ return uids
+
+ @staticmethod
+ def _load_rgba_image(file_path, bg_color: float = 1.0):
+ ''' Load and blend RGBA image to RGB with certain background, 0-1 scaled '''
+ rgba = np.array(Image.open(smart_open(file_path, 'rb')))
+ rgba = torch.from_numpy(rgba).float() / 255.0
+ rgba = rgba.permute(2, 0, 1).unsqueeze(0)
+ rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (1 - rgba[:, 3:, :, :])
+ rgba[:, :3, ...] * rgba[:, 3:, ...] + (1 - rgba[:, 3:, ...])
+ return rgb
+
+ @staticmethod
+ def _locate_datadir(root_dirs, uid, locator: str):
+ for root_dir in root_dirs:
+ datadir = smart_path_join(root_dir, uid, locator)
+ if smart_exists(datadir):
+ return root_dir
+ raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}")
diff --git a/lam/datasets/cam_utils.py b/lam/datasets/cam_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..70653ae2a7f612714f729c73f45e826109b7e0ff
--- /dev/null
+++ b/lam/datasets/cam_utils.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+import torch
+
+"""
+R: (N, 3, 3)
+T: (N, 3)
+E: (N, 4, 4)
+vector: (N, 3)
+"""
+
+
+def compose_extrinsic_R_T(R: torch.Tensor, T: torch.Tensor):
+ """
+ Compose the standard form extrinsic matrix from R and T.
+ Batched I/O.
+ """
+ RT = torch.cat((R, T.unsqueeze(-1)), dim=-1)
+ return compose_extrinsic_RT(RT)
+
+
+def compose_extrinsic_RT(RT: torch.Tensor):
+ """
+ Compose the standard form extrinsic matrix from RT.
+ Batched I/O.
+ """
+ return torch.cat([
+ RT,
+ torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1)
+ ], dim=1)
+
+
+def decompose_extrinsic_R_T(E: torch.Tensor):
+ """
+ Decompose the standard extrinsic matrix into R and T.
+ Batched I/O.
+ """
+ RT = decompose_extrinsic_RT(E)
+ return RT[:, :, :3], RT[:, :, 3]
+
+
+def decompose_extrinsic_RT(E: torch.Tensor):
+ """
+ Decompose the standard extrinsic matrix into RT.
+ Batched I/O.
+ """
+ return E[:, :3, :]
+
+
+def camera_normalization_objaverse(normed_dist_to_center, poses: torch.Tensor, ret_transform: bool = False):
+ assert normed_dist_to_center is not None
+ pivotal_pose = compose_extrinsic_RT(poses[:1])
+ dist_to_center = pivotal_pose[:, :3, 3].norm(dim=-1, keepdim=True).item() \
+ if normed_dist_to_center == 'auto' else normed_dist_to_center
+
+ # compute camera norm (new version)
+ canonical_camera_extrinsics = torch.tensor([[
+ [1, 0, 0, 0],
+ [0, 0, -1, -dist_to_center],
+ [0, 1, 0, 0],
+ [0, 0, 0, 1],
+ ]], dtype=torch.float32)
+ pivotal_pose_inv = torch.inverse(pivotal_pose)
+ camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)
+
+ # normalize all views
+ poses = compose_extrinsic_RT(poses)
+ poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses)
+ poses = decompose_extrinsic_RT(poses)
+
+ if ret_transform:
+ return poses, camera_norm_matrix.squeeze(dim=0)
+ return poses
+
+
+def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
+ """
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+ Return batched fx, fy, cx, cy
+ """
+ fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
+ cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
+ width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
+ fx, fy = fx / width, fy / height
+ cx, cy = cx / width, cy / height
+ return fx, fy, cx, cy
+
+
+def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
+ """
+ RT: (N, 3, 4)
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+ """
+ fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
+ return torch.cat([
+ RT.reshape(-1, 12),
+ fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1),
+ ], dim=-1)
+
+
+def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
+ """
+ RT: (N, 3, 4)
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+ """
+ E = compose_extrinsic_RT(RT)
+ fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
+ I = torch.stack([
+ torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
+ torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
+ torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
+ ], dim=1)
+ return torch.cat([
+ E.reshape(-1, 16),
+ I.reshape(-1, 9),
+ ], dim=-1)
+
+
+def center_looking_at_camera_pose(
+ camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None,
+ device: torch.device = torch.device('cpu'),
+ ):
+ """
+ camera_position: (M, 3)
+ look_at: (3)
+ up_world: (3)
+ return: (M, 3, 4)
+ """
+ # by default, looking at the origin and world up is pos-z
+ if look_at is None:
+ look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
+ if up_world is None:
+ up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
+ look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
+ up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
+
+ z_axis = camera_position - look_at
+ z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
+ x_axis = torch.cross(up_world, z_axis)
+ x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
+ y_axis = torch.cross(z_axis, x_axis)
+ y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
+ extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
+ return extrinsics
+
+
+def surrounding_views_linspace(n_views: int, radius: float = 2.0, height: float = 0.8, device: torch.device = torch.device('cpu')):
+ """
+ n_views: number of surrounding views
+ radius: camera dist to center
+ height: height of the camera
+ return: (M, 3, 4)
+ """
+ assert n_views > 0
+ assert radius > 0
+
+ theta = torch.linspace(-torch.pi / 2, 3 * torch.pi / 2, n_views, device=device)
+ projected_radius = math.sqrt(radius ** 2 - height ** 2)
+ x = torch.cos(theta) * projected_radius
+ y = torch.sin(theta) * projected_radius
+ z = torch.full((n_views,), height, device=device)
+
+ camera_positions = torch.stack([x, y, z], dim=1)
+ extrinsics = center_looking_at_camera_pose(camera_positions, device=device)
+
+ return extrinsics
+
+
+def create_intrinsics(
+ f: float,
+ c: float = None, cx: float = None, cy: float = None,
+ w: float = 1., h: float = 1.,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = torch.device('cpu'),
+ ):
+ """
+ return: (3, 2)
+ """
+ fx = fy = f
+ if c is not None:
+ assert cx is None and cy is None, "c and cx/cy cannot be used together"
+ cx = cy = c
+ else:
+ assert cx is not None and cy is not None, "cx/cy must be provided when c is not provided"
+ fx, fy, cx, cy, w, h = fx/w, fy/h, cx/w, cy/h, 1., 1.
+ intrinsics = torch.tensor([
+ [fx, fy],
+ [cx, cy],
+ [w, h],
+ ], dtype=dtype, device=device)
+ return intrinsics
diff --git a/lam/datasets/mixer.py b/lam/datasets/mixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..03d3c9f7964eb1314d8d8739de57c130df4bbae3
--- /dev/null
+++ b/lam/datasets/mixer.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from functools import partial
+import torch
+
+__all__ = ['MixerDataset']
+
+
+class MixerDataset(torch.utils.data.Dataset):
+
+ def __init__(self,
+ split: str,
+ subsets: dict,
+ **dataset_kwargs,
+ ):
+ subsets = [e for e in subsets if e["meta_path"][split] is not None]
+ self.subsets = [
+ self._dataset_fn(subset, split)(**dataset_kwargs)
+ for subset in subsets
+ ]
+ self.virtual_lens = [
+ math.ceil(subset_config['sample_rate'] * len(subset_obj))
+ for subset_config, subset_obj in zip(subsets, self.subsets)
+ ]
+
+ @staticmethod
+ def _dataset_fn(subset_config: dict, split: str):
+ name = subset_config['name']
+
+ dataset_cls = None
+ if name == "exavatar":
+ from .exavatar import ExAvatarDataset
+ dataset_cls = ExAvatarDataset
+ elif name == "humman":
+ from .humman import HuMManDataset
+ dataset_cls = HuMManDataset
+ elif name == "humman_ori":
+ from .humman_ori import HuMManOriDataset
+ dataset_cls = HuMManOriDataset
+ elif name == "static_human":
+ from .static_human import StaticHumanDataset
+ dataset_cls = StaticHumanDataset
+ elif name == "singleview_human":
+ from .singleview_human import SingleViewHumanDataset
+ dataset_cls = SingleViewHumanDataset
+ elif name == "singleview_square_human":
+ from .singleview_square_human import SingleViewSquareHumanDataset
+ dataset_cls = SingleViewSquareHumanDataset
+ elif name == "bedlam":
+ from .bedlam import BedlamDataset
+ dataset_cls = BedlamDataset
+ elif name == "dna_human":
+ from .dna import DNAHumanDataset
+ dataset_cls = DNAHumanDataset
+ elif name == "video_human":
+ from .video_human import VideoHumanDataset
+ dataset_cls = VideoHumanDataset
+ elif name == "video_head":
+ from .video_head import VideoHeadDataset
+ dataset_cls = VideoHeadDataset
+ elif name == "video_head_gagtrack":
+ from .video_head_gagtrack import VideoHeadGagDataset
+ dataset_cls = VideoHeadGagDataset
+ elif name == "objaverse":
+ from .objaverse import ObjaverseDataset
+ dataset_cls = ObjaverseDataset
+ # elif name == 'mvimgnet':
+ # from .mvimgnet import MVImgNetDataset
+ # dataset_cls = MVImgNetDataset
+ else:
+ raise NotImplementedError(f"Dataset {name} not implemented")
+ print("==="*16*3, "\nUse dataset loader:", name, "\n"+"==="*3*16)
+
+ return partial(
+ dataset_cls,
+ root_dirs=subset_config['root_dirs'],
+ meta_path=subset_config['meta_path'][split],
+ )
+
+ def __len__(self):
+ return sum(self.virtual_lens)
+
+ def __getitem__(self, idx):
+ subset_idx = 0
+ virtual_idx = idx
+ while virtual_idx >= self.virtual_lens[subset_idx]:
+ virtual_idx -= self.virtual_lens[subset_idx]
+ subset_idx += 1
+ real_idx = virtual_idx % len(self.subsets[subset_idx])
+ return self.subsets[subset_idx][real_idx]
diff --git a/lam/datasets/video_head.py b/lam/datasets/video_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..f31baa002704afe150c8247220fa26d6356356a3
--- /dev/null
+++ b/lam/datasets/video_head.py
@@ -0,0 +1,655 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import defaultdict
+import os
+import glob
+from typing import Union
+import random
+import numpy as np
+import torch
+# from megfile import smart_path_join, smart_open
+import json
+from PIL import Image
+import cv2
+
+from lam.datasets.base import BaseDataset
+from lam.datasets.cam_utils import build_camera_standard, build_camera_principle, camera_normalization_objaverse
+from lam.utils.proxy import no_proxy
+from typing import Optional, Union
+
+__all__ = ['VideoHeadDataset']
+
+
+class VideoHeadDataset(BaseDataset):
+
+ def __init__(self, root_dirs: str, meta_path: Optional[Union[str, list]],
+ sample_side_views: int,
+ render_image_res_low: int, render_image_res_high: int, render_region_size: int,
+ source_image_res: int,
+ repeat_num=1,
+ crop_range_ratio_hw=[1.0, 1.0],
+ aspect_standard=1.0, # h/w
+ enlarge_ratio=[0.8, 1.2],
+ debug=False,
+ is_val=False,
+ **kwargs):
+ super().__init__(root_dirs, meta_path)
+ self.sample_side_views = sample_side_views
+ self.render_image_res_low = render_image_res_low
+ self.render_image_res_high = render_image_res_high
+ if not (isinstance(render_region_size, list) or isinstance(render_region_size, tuple)):
+ render_region_size = render_region_size, render_region_size # [H, W]
+ self.render_region_size = render_region_size
+ self.source_image_res = source_image_res
+
+ self.uids = self.uids * repeat_num
+ self.crop_range_ratio_hw = crop_range_ratio_hw
+ self.debug = debug
+ self.aspect_standard = aspect_standard
+
+ assert self.render_image_res_low == self.render_image_res_high
+ self.render_image_res = self.render_image_res_low
+ self.enlarge_ratio = enlarge_ratio
+ print(f"VideoHeadDataset, data_len:{len(self.uids)}, repeat_num:{repeat_num}, debug:{debug}, is_val:{is_val}")
+ self.multiply = kwargs.get("multiply", 14)
+ # set data deterministic
+ self.is_val = is_val
+
+ @staticmethod
+ def _load_pose(frame_info, transpose_R=False):
+ c2w = torch.eye(4)
+ c2w = np.array(frame_info["transform_matrix"])
+ c2w[:3, 1:3] *= -1
+ c2w = torch.FloatTensor(c2w)
+ """
+ if transpose_R:
+ w2c = torch.inverse(c2w)
+ w2c[:3, :3] = w2c[:3, :3].transpose(1, 0).contiguous()
+ c2w = torch.inverse(w2c)
+ """
+
+ intrinsic = torch.eye(4)
+ intrinsic[0, 0] = frame_info["fl_x"]
+ intrinsic[1, 1] = frame_info["fl_y"]
+ intrinsic[0, 2] = frame_info["cx"]
+ intrinsic[1, 2] = frame_info["cy"]
+ intrinsic = intrinsic.float()
+
+ return c2w, intrinsic
+
+ def img_center_padding(self, img_np, pad_ratio):
+
+ ori_w, ori_h = img_np.shape[:2]
+
+ w = round((1 + pad_ratio) * ori_w)
+ h = round((1 + pad_ratio) * ori_h)
+
+ if len(img_np.shape) > 2:
+ img_pad_np = np.zeros((w, h, img_np.shape[2]), dtype=np.uint8)
+ else:
+ img_pad_np = np.zeros((w, h), dtype=np.uint8)
+ offset_h, offset_w = (w - img_np.shape[0]) // 2, (h - img_np.shape[1]) // 2
+ img_pad_np[offset_h: offset_h + img_np.shape[0]:, offset_w: offset_w + img_np.shape[1]] = img_np
+
+ return img_pad_np
+
+ def resize_image_keepaspect_np(self, img, max_tgt_size):
+ """
+ similar to ImageOps.contain(img_pil, (img_size, img_size)) # keep the same aspect ratio
+ """
+ h, w = img.shape[:2]
+ ratio = max_tgt_size / max(h, w)
+ new_h, new_w = round(h * ratio), round(w * ratio)
+ return cv2.resize(img, dsize=(new_w, new_h), interpolation=cv2.INTER_AREA)
+
+ def center_crop_according_to_mask(self, img, mask, aspect_standard, enlarge_ratio):
+ """
+ img: [H, W, 3]
+ mask: [H, W]
+ """
+ ys, xs = np.where(mask > 0)
+
+ if len(xs) == 0 or len(ys) == 0:
+ raise Exception("empty mask")
+
+ x_min = np.min(xs)
+ x_max = np.max(xs)
+ y_min = np.min(ys)
+ y_max = np.max(ys)
+
+ center_x, center_y = img.shape[1]//2, img.shape[0]//2
+
+ half_w = max(abs(center_x - x_min), abs(center_x - x_max))
+ half_h = max(abs(center_y - y_min), abs(center_y - y_max))
+ aspect = half_h / half_w
+
+ if aspect >= aspect_standard:
+ half_w = round(half_h / aspect_standard)
+ else:
+ half_h = round(half_w * aspect_standard)
+
+ if abs(enlarge_ratio[0] - 1) > 0.01 or abs(enlarge_ratio[1] - 1) > 0.01:
+ enlarge_ratio_min, enlarge_ratio_max = enlarge_ratio
+ enlarge_ratio_max_real = min(center_y / half_h, center_x / half_w)
+ enlarge_ratio_max = min(enlarge_ratio_max_real, enlarge_ratio_max)
+ enlarge_ratio_min = min(enlarge_ratio_max_real, enlarge_ratio_min)
+ enlarge_ratio_cur = np.random.rand() * (enlarge_ratio_max - enlarge_ratio_min) + enlarge_ratio_min
+ half_h, half_w = round(enlarge_ratio_cur * half_h), round(enlarge_ratio_cur * half_w)
+
+ assert half_h <= center_y
+ assert half_w <= center_x
+ assert abs(half_h / half_w - aspect_standard) < 0.03
+
+ offset_x = center_x - half_w
+ offset_y = center_y - half_h
+
+ new_img = img[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]
+ new_mask = mask[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]
+
+ return new_img, new_mask, offset_x, offset_y
+
+ def load_rgb_image_with_aug_bg(self, rgb_path, mask_path, bg_color, pad_ratio, max_tgt_size, aspect_standard, enlarge_ratio,
+ render_tgt_size, multiply, intr):
+ rgb = np.array(Image.open(rgb_path))
+ interpolation = cv2.INTER_AREA
+ if rgb.shape[0] != 1024 and rgb.shape[0] == rgb.shape[1]:
+ rgb = cv2.resize(rgb, (1024, 1024), interpolation=interpolation)
+ if pad_ratio > 0:
+ rgb = self.img_center_padding(rgb, pad_ratio)
+
+ rgb = rgb / 255.0
+ if mask_path is not None:
+ if os.path.exists(mask_path):
+ mask = np.array(Image.open(mask_path)) > 180
+ if len(mask.shape) == 3:
+ mask = mask[..., 0]
+ assert pad_ratio == 0
+ # if pad_ratio > 0:
+ # mask = self.img_center_padding(mask, pad_ratio)
+ # mask = mask / 255.0
+ else:
+ # print("no mask file")
+ mask = (rgb >= 0.99).sum(axis=2) == 3
+ mask = np.logical_not(mask)
+ # erode
+ mask = (mask * 255).astype(np.uint8)
+ kernel_size, iterations = 3, 7
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
+ mask = cv2.erode(mask, kernel, iterations=iterations) / 255.0
+ else:
+ # rgb: [H, W, 4]
+ assert rgb.shape[2] == 4
+ mask = rgb[:, :, 3] # [H, W]
+ if len(mask.shape) > 2:
+ mask = mask[:, :, 0]
+
+ mask = (mask > 0.5).astype(np.float32)
+ rgb = rgb[:, :, :3] * mask[:, :, None] + bg_color * (1 - mask[:, :, None])
+
+ # crop image to enlarge face area.
+ try:
+ rgb, mask, offset_x, offset_y = self.center_crop_according_to_mask(rgb, mask, aspect_standard, enlarge_ratio)
+ except Exception as ex:
+ print(rgb_path, mask_path, ex)
+
+ intr[0, 2] -= offset_x
+ intr[1, 2] -= offset_y
+
+ # resize to render_tgt_size for training
+ tgt_hw_size, ratio_y, ratio_x = self.calc_new_tgt_size_by_aspect(cur_hw=rgb.shape[:2],
+ aspect_standard=aspect_standard,
+ tgt_size=render_tgt_size, multiply=multiply)
+ rgb = cv2.resize(rgb, dsize=(tgt_hw_size[1], tgt_hw_size[0]), interpolation=interpolation)
+ mask = cv2.resize(mask, dsize=(tgt_hw_size[1], tgt_hw_size[0]), interpolation=interpolation)
+ intr = self.scale_intrs(intr, ratio_x=ratio_x, ratio_y=ratio_y)
+
+ assert abs(intr[0, 2] * 2 - rgb.shape[1]) < 2.5, f"{intr[0, 2] * 2}, {rgb.shape[1]}"
+ assert abs(intr[1, 2] * 2 - rgb.shape[0]) < 2.5, f"{intr[1, 2] * 2}, {rgb.shape[0]}"
+ intr[0, 2] = rgb.shape[1] // 2
+ intr[1, 2] = rgb.shape[0] // 2
+
+ rgb = torch.from_numpy(rgb).float().permute(2, 0, 1).unsqueeze(0)
+ mask = torch.from_numpy(mask[:, :, None]).float().permute(2, 0, 1).unsqueeze(0)
+
+ return rgb, mask, intr
+
+ def scale_intrs(self, intrs, ratio_x, ratio_y):
+ if len(intrs.shape) >= 3:
+ intrs[:, 0] = intrs[:, 0] * ratio_x
+ intrs[:, 1] = intrs[:, 1] * ratio_y
+ else:
+ intrs[0] = intrs[0] * ratio_x
+ intrs[1] = intrs[1] * ratio_y
+ return intrs
+
+ def uniform_sample_in_chunk(self, sample_num, sample_data):
+ chunks = np.array_split(sample_data, sample_num)
+ select_list = []
+ for chunk in chunks:
+ select_list.append(np.random.choice(chunk))
+ return select_list
+
+ def uniform_sample_in_chunk_det(self, sample_num, sample_data):
+ chunks = np.array_split(sample_data, sample_num)
+ select_list = []
+ for chunk in chunks:
+ select_list.append(chunk[len(chunk)//2])
+ return select_list
+
+ def calc_new_tgt_size(self, cur_hw, tgt_size, multiply):
+ ratio = tgt_size / min(cur_hw)
+ tgt_size = int(ratio * cur_hw[0]), int(ratio * cur_hw[1])
+ tgt_size = int(tgt_size[0] / multiply) * multiply, int(tgt_size[1] / multiply) * multiply
+ ratio_y, ratio_x = tgt_size[0] / cur_hw[0], tgt_size[1] / cur_hw[1]
+ return tgt_size, ratio_y, ratio_x
+
+ def calc_new_tgt_size_by_aspect(self, cur_hw, aspect_standard, tgt_size, multiply):
+ assert abs(cur_hw[0] / cur_hw[1] - aspect_standard) < 0.03
+ tgt_size = tgt_size * aspect_standard, tgt_size
+ tgt_size = int(tgt_size[0] / multiply) * multiply, int(tgt_size[1] / multiply) * multiply
+ ratio_y, ratio_x = tgt_size[0] / cur_hw[0], tgt_size[1] / cur_hw[1]
+ return tgt_size, ratio_y, ratio_x
+
+ def load_flame_params(self, flame_file_path, teeth_bs=None):
+
+ flame_param = dict(np.load(flame_file_path), allow_pickle=True)
+
+ flame_param_tensor = {}
+ flame_param_tensor['expr'] = torch.FloatTensor(flame_param['expr'])[0]
+ flame_param_tensor['rotation'] = torch.FloatTensor(flame_param['rotation'])[0]
+ flame_param_tensor['neck_pose'] = torch.FloatTensor(flame_param['neck_pose'])[0]
+ flame_param_tensor['jaw_pose'] = torch.FloatTensor(flame_param['jaw_pose'])[0]
+ flame_param_tensor['eyes_pose'] = torch.FloatTensor(flame_param['eyes_pose'])[0]
+ flame_param_tensor['translation'] = torch.FloatTensor(flame_param['translation'])[0]
+ if teeth_bs is not None:
+ flame_param_tensor['teeth_bs'] = torch.FloatTensor(teeth_bs)
+ # flame_param_tensor['expr'] = torch.cat([flame_param_tensor['expr'], flame_param_tensor['teeth_bs']], dim=0)
+
+ return flame_param_tensor
+
+ @no_proxy
+ def inner_get_item(self, idx):
+ """
+ Loaded contents:
+ rgbs: [M, 3, H, W]
+ poses: [M, 3, 4], [R|t]
+ intrinsics: [3, 2], [[fx, fy], [cx, cy], [weight, height]]
+ """
+ crop_ratio_h, crop_ratio_w = self.crop_range_ratio_hw
+
+ uid = self.uids[idx]
+ if len(uid.split('/')) == 1:
+ uid = os.path.join(self.root_dirs, uid)
+ mode_str = "train" if not self.is_val else "test"
+ transforms_json = os.path.join(uid, f"transforms_{mode_str}.json")
+
+ with open(transforms_json) as fp:
+ data = json.load(fp)
+ cor_flame_path = transforms_json.replace('transforms_{}.json'.format(mode_str),'canonical_flame_param.npz')
+ flame_param = np.load(cor_flame_path)
+ shape_param = torch.FloatTensor(flame_param['shape'])
+ # data['static_offset'] = flame_param['static_offset']
+
+ all_frames = data["frames"]
+
+ sample_total_views = self.sample_side_views + 1
+ if len(all_frames) >= self.sample_side_views:
+ if not self.is_val:
+ if np.random.rand() < 0.7 and len(all_frames) > sample_total_views:
+ frame_id_list = self.uniform_sample_in_chunk(sample_total_views, np.arange(len(all_frames)))
+ else:
+ replace = len(all_frames) < sample_total_views
+ frame_id_list = np.random.choice(len(all_frames), size=sample_total_views, replace=replace)
+ else:
+ if len(all_frames) > sample_total_views:
+ frame_id_list = self.uniform_sample_in_chunk_det(sample_total_views, np.arange(len(all_frames)))
+ else:
+ frame_id_list = np.random.choice(len(all_frames), size=sample_total_views, replace=True)
+ else:
+ if not self.is_val:
+ replace = len(all_frames) < sample_total_views
+ frame_id_list = np.random.choice(len(all_frames), size=sample_total_views, replace=replace)
+ else:
+ if len(all_frames) > 1:
+ frame_id_list = np.linspace(0, len(all_frames) - 1, num=sample_total_views, endpoint=True)
+ frame_id_list = [round(e) for e in frame_id_list]
+ else:
+ frame_id_list = [0 for i in range(sample_total_views)]
+
+ cam_id_list = frame_id_list
+
+ assert self.sample_side_views + 1 == len(frame_id_list)
+
+ # source images
+ c2ws, intrs, rgbs, bg_colors, masks = [], [], [], [], []
+ flame_params = []
+ teeth_bs_pth = os.path.join(uid, "tracked_teeth_bs.npz")
+ use_teeth = False
+ if os.path.exists(teeth_bs_pth) and use_teeth:
+ teeth_bs_lst = np.load(teeth_bs_pth)['expr_teeth']
+ else:
+ teeth_bs_lst = None
+ for cam_id, frame_id in zip(cam_id_list, frame_id_list):
+ frame_info = all_frames[frame_id]
+ frame_path = os.path.join(uid, frame_info["file_path"])
+ if 'nersemble' in frame_path or "tiktok_v34" in frame_path:
+ mask_path = os.path.join(uid, frame_info["fg_mask_path"])
+ else:
+ mask_path = os.path.join(uid, frame_info["fg_mask_path"]).replace("/export/", "/mask/").replace("/fg_masks/", "/mask/").replace(".png", ".jpg")
+ if not os.path.exists(mask_path):
+ mask_path = os.path.join(uid, frame_info["fg_mask_path"])
+
+ teeth_bs = teeth_bs_lst[frame_id] if teeth_bs_lst is not None else None
+ flame_path = os.path.join(uid, frame_info["flame_param_path"])
+ flame_param = self.load_flame_params(flame_path, teeth_bs)
+
+ # if cam_id == 0:
+ # shape_param = flame_param["betas"]
+
+ c2w, ori_intrinsic = self._load_pose(frame_info, transpose_R="nersemble" in frame_path)
+
+ bg_color = random.choice([0.0, 0.5, 1.0]) # 1.0
+ # if self.is_val:
+ # bg_color = 1.0
+ rgb, mask, intrinsic = self.load_rgb_image_with_aug_bg(frame_path, mask_path=mask_path,
+ bg_color=bg_color,
+ pad_ratio=0,
+ max_tgt_size=None,
+ aspect_standard=self.aspect_standard,
+ enlarge_ratio=self.enlarge_ratio if (not self.is_val) or ("nersemble" in frame_path) else [1.0, 1.0],
+ render_tgt_size=self.render_image_res,
+ multiply=16,
+ intr=ori_intrinsic.clone())
+ c2ws.append(c2w)
+ rgbs.append(rgb)
+ bg_colors.append(bg_color)
+ intrs.append(intrinsic)
+ flame_params.append(flame_param)
+ masks.append(mask)
+
+ c2ws = torch.stack(c2ws, dim=0) # [N, 4, 4]
+ intrs = torch.stack(intrs, dim=0) # [N, 4, 4]
+ rgbs = torch.cat(rgbs, dim=0) # [N, 3, H, W]
+ bg_colors = torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1).repeat(1, 3) # [N, 3]
+ masks = torch.cat(masks, dim=0) # [N, 1, H, W]
+
+ flame_params_tmp = defaultdict(list)
+ for flame in flame_params:
+ for k, v in flame.items():
+ flame_params_tmp[k].append(v)
+ for k, v in flame_params_tmp.items():
+ flame_params_tmp[k] = torch.stack(v)
+ flame_params = flame_params_tmp
+ # TODO check different betas for same person
+ flame_params["betas"] = shape_param
+
+ # reference images
+ prob_refidx = np.ones(self.sample_side_views + 1)
+ if not self.is_val:
+ prob_refidx[0] = 0.5 # front_prob
+ else:
+ prob_refidx[0] = 1.0
+ # print(frame_id_list, kinect_color_list, prob_refidx[0])
+ prob_refidx[1:] = (1 - prob_refidx[0]) / len(prob_refidx[1:])
+ ref_idx = np.random.choice(self.sample_side_views + 1, p=prob_refidx)
+ cam_id_source_list = cam_id_list[ref_idx: ref_idx + 1]
+ frame_id_source_list = frame_id_list[ref_idx: ref_idx + 1]
+
+ source_c2ws, source_intrs, source_rgbs, source_flame_params = [], [], [], []
+ for cam_id, frame_id in zip(cam_id_source_list, frame_id_source_list):
+ frame_info = all_frames[frame_id]
+ frame_path = os.path.join(uid, frame_info["file_path"])
+ if 'nersemble' in frame_path:
+ mask_path = os.path.join(uid, frame_info["fg_mask_path"])
+ else:
+ mask_path = os.path.join(uid, frame_info["fg_mask_path"]).replace("/export/", "/mask/").replace("/fg_masks/", "/mask/").replace(".png", ".jpg")
+ flame_path = os.path.join(uid, frame_info["flame_param_path"])
+
+ teeth_bs = teeth_bs_lst[frame_id] if teeth_bs_lst is not None else None
+ flame_param = self.load_flame_params(flame_path, teeth_bs)
+
+ c2w, ori_intrinsic = self._load_pose(frame_info)
+
+ # bg_color = 1.0
+ # bg_color = 0.0
+ bg_color = random.choice([0.0, 0.5, 1.0]) # 1.
+ rgb, mask, intrinsic = self.load_rgb_image_with_aug_bg(frame_path, mask_path=mask_path,
+ bg_color=bg_color,
+ pad_ratio=0,
+ max_tgt_size=None,
+ aspect_standard=self.aspect_standard,
+ enlarge_ratio=self.enlarge_ratio if (not self.is_val) or ("nersemble" in frame_path) else [1.0, 1.0],
+ render_tgt_size=self.source_image_res,
+ multiply=self.multiply,
+ intr=ori_intrinsic.clone())
+
+ source_c2ws.append(c2w)
+ source_intrs.append(intrinsic)
+ source_rgbs.append(rgb)
+ source_flame_params.append(flame_param)
+
+ source_c2ws = torch.stack(source_c2ws, dim=0)
+ source_intrs = torch.stack(source_intrs, dim=0)
+ source_rgbs = torch.cat(source_rgbs, dim=0)
+
+ flame_params_tmp = defaultdict(list)
+ for flame in source_flame_params:
+ for k, v in flame.items():
+ flame_params_tmp['source_'+k].append(v)
+ for k, v in flame_params_tmp.items():
+ flame_params_tmp[k] = torch.stack(v)
+ source_flame_params = flame_params_tmp
+ # TODO check different betas for same person
+ source_flame_params["source_betas"] = shape_param
+
+ render_image = rgbs
+ render_mask = masks
+ tgt_size = render_image.shape[2:4] # [H, W]
+ assert abs(intrs[0, 0, 2] * 2 - render_image.shape[3]) <= 1.1, f"{intrs[0, 0, 2] * 2}, {render_image.shape}"
+ assert abs(intrs[0, 1, 2] * 2 - render_image.shape[2]) <= 1.1, f"{intrs[0, 1, 2] * 2}, {render_image.shape}"
+
+ ret = {
+ 'uid': uid,
+ 'source_c2ws': source_c2ws, # [N1, 4, 4]
+ 'source_intrs': source_intrs, # [N1, 4, 4]
+ 'source_rgbs': source_rgbs.clamp(0, 1), # [N1, 3, H, W]
+ 'render_image': render_image.clamp(0, 1), # [N, 3, H, W]
+ 'render_mask': render_mask.clamp(0, 1), #[ N, 1, H, W]
+ 'c2ws': c2ws, # [N, 4, 4]
+ 'intrs': intrs, # [N, 4, 4]
+ 'render_full_resolutions': torch.tensor([tgt_size], dtype=torch.float32).repeat(self.sample_side_views + 1, 1), # [N, 2]
+ 'render_bg_colors': bg_colors, # [N, 3]
+ 'pytorch3d_transpose_R': torch.Tensor(["nersemble" in frame_path]), # [1]
+ }
+
+ #['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas']
+ # 'flame_params': flame_params, # dict: body_pose:[N, 21, 3],
+ ret.update(flame_params)
+ ret.update(source_flame_params)
+
+ return ret
+
+def gen_valid_id_json():
+ root_dir = "./train_data/vfhq_vhap/export"
+ save_path = "./train_data/vfhq_vhap/label/valid_id_list.json"
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ valid_id_list = []
+ for file in os.listdir(root_dir):
+ if not file.startswith("."):
+ valid_id_list.append(file)
+ print(len(valid_id_list), valid_id_list[:2])
+ with open(save_path, "w") as fp:
+ json.dump(valid_id_list, fp)
+
+
+def gen_valid_id_json():
+ root_dir = "./train_data/vfhq_vhap/export"
+ mask_root_dir = "./train_data/vfhq_vhap/mask"
+ save_path = "./train_data/vfhq_vhap/label/valid_id_list.json"
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ valid_id_list = []
+ for file in os.listdir(root_dir):
+ if not file.startswith(".") and ".txt" not in file:
+ valid_id_list.append(file)
+ print("raw:", len(valid_id_list), valid_id_list[:2])
+
+ mask_valid_id_list = []
+ for file in os.listdir(mask_root_dir):
+ if not file.startswith(".") and ".txt" not in file:
+ mask_valid_id_list.append(file)
+ print("mask:", len(mask_valid_id_list), mask_valid_id_list[:2])
+
+ valid_id_list = list(set(valid_id_list).intersection(set(mask_valid_id_list)))
+ print("intesection:", len(mask_valid_id_list), mask_valid_id_list[:2])
+
+ with open(save_path, "w") as fp:
+ json.dump(valid_id_list, fp)
+
+ save_train_path = "./train_data/vfhq_vhap/label/valid_id_train_list.json"
+ save_val_path = "./train_data/vfhq_vhap/label/valid_id_val_list.json"
+ valid_id_list = sorted(valid_id_list)
+ idxs = np.linspace(0, len(valid_id_list)-1, num=20, endpoint=True).astype(np.int64)
+ valid_id_train_list = []
+ valid_id_val_list = []
+ for i in range(len(valid_id_list)):
+ if i in idxs:
+ valid_id_val_list.append(valid_id_list[i])
+ else:
+ valid_id_train_list.append(valid_id_list[i])
+
+ print(len(valid_id_train_list), len(valid_id_val_list), valid_id_val_list)
+ with open(save_train_path, "w") as fp:
+ json.dump(valid_id_train_list, fp)
+
+ with open(save_val_path, "w") as fp:
+ json.dump(valid_id_val_list, fp)
+
+
+if __name__ == "__main__":
+ import trimesh
+ import cv2
+ root_dir = "./train_data/vfhq_vhap/export"
+ meta_path = "./train_data/vfhq_vhap/label/valid_id_list.json"
+ dataset = VideoHeadDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=15,
+ render_image_res_low=512, render_image_res_high=512,
+ render_region_size=(512, 512), source_image_res=512,
+ enlarge_ratio=[0.8, 1.2],
+ debug=False, is_val=False)
+
+ from lam.models.rendering.flame_model.flame import FlameHeadSubdivided
+
+ # subdivided flame
+ subdivide = 2
+ flame_sub_model = FlameHeadSubdivided(
+ 300,
+ 100,
+ add_teeth=True,
+ add_shoulder=False,
+ flame_model_path='pretrained_models/human_model_files/flame_assets/flame/flame2023.pkl',
+ flame_lmk_embedding_path="pretrained_models/human_model_files/flame_assets/flame/landmark_embedding_with_eyes.npy",
+ flame_template_mesh_path="pretrained_models/human_model_files/flame_assets/flame/head_template_mesh.obj",
+ flame_parts_path="pretrained_models/human_model_files/flame_assets/flame/FLAME_masks.pkl",
+ subdivide_num=subdivide,
+ teeth_bs_flag=False,
+ ).cuda()
+
+ source_key = "source_rgbs"
+ render_key = "render_image"
+
+ for idx, data in enumerate(dataset):
+ import boxx
+ boxx.tree(data)
+ if idx > 0:
+ exit(0)
+ os.makedirs("debug_vis/dataloader", exist_ok=True)
+ for i in range(data[source_key].shape[0]):
+ cv2.imwrite(f"debug_vis/dataloader/{source_key}_{i}_b{idx}.jpg", ((data[source_key][i].permute(1, 2, 0).numpy()[:, :, (2, 1, 0)] * 255).astype(np.uint8)))
+
+ for i in range(data[render_key].shape[0]):
+ cv2.imwrite(f"debug_vis/dataloader/rgbs{i}_b{idx}.jpg", ((data[render_key][i].permute(1, 2, 0).numpy()[:, :, (2, 1, 0)] * 255).astype(np.uint8)))
+
+
+ save_root = "./debug_vis/dataloader"
+ os.makedirs(save_root, exist_ok=True)
+
+ shape = data['betas'].to('cuda')
+ flame_param = {}
+ flame_param['expr'] = data['expr'].to('cuda')
+ flame_param['rotation'] = data['rotation'].to('cuda')
+ flame_param['neck'] = data['neck_pose'].to('cuda')
+ flame_param['jaw'] = data['jaw_pose'].to('cuda')
+ flame_param['eyes'] = data['eyes_pose'].to('cuda')
+ flame_param['translation'] = data['translation'].to('cuda')
+
+
+ v_cano = flame_sub_model.get_cano_verts(
+ shape.unsqueeze(0)
+ )
+ ret = flame_sub_model.animation_forward(
+ v_cano.repeat(flame_param['expr'].shape[0], 1, 1),
+ shape.unsqueeze(0).repeat(flame_param['expr'].shape[0], 1),
+ flame_param['expr'],
+ flame_param['rotation'],
+ flame_param['neck'],
+ flame_param['jaw'],
+ flame_param['eyes'],
+ flame_param['translation'],
+ zero_centered_at_root_node=False,
+ return_landmarks=False,
+ return_verts_cano=True,
+ # static_offset=batch_data['static_offset'].to('cuda'),
+ static_offset=None,
+ )
+
+ import boxx
+ boxx.tree(data)
+ boxx.tree(ret)
+
+ for i in range(ret["animated"].shape[0]):
+ mesh = trimesh.Trimesh()
+ mesh.vertices = np.array(ret["animated"][i].cpu().squeeze())
+ mesh.faces = np.array(flame_sub_model.faces.cpu().squeeze())
+ mesh.export(f'{save_root}/animated_sub{subdivide}_{i}.obj')
+
+ intr = data["intrs"][i]
+ from lam.models.rendering.utils.vis_utils import render_mesh
+ cam_param = {"focal": torch.tensor([intr[0, 0], intr[1, 1]]),
+ "princpt": torch.tensor([intr[0, 2], intr[1, 2]])}
+ render_shape = data[render_key].shape[2:] # int(cam_param['princpt'][1]* 2), int(cam_param['princpt'][0] * 2)
+
+ face = flame_sub_model.faces.cpu().squeeze().numpy()
+ vertices = ret["animated"][i].cpu().squeeze()
+
+ c2ws = data["c2ws"][i]
+ w2cs = torch.inverse(c2ws)
+ if data['pytorch3d_transpose_R'][0] > 0:
+ R = w2cs[:3, :3].transpose(1, 0)
+ else:
+ R = w2cs[:3, :3]
+ T = w2cs[:3, 3]
+ vertices = vertices @ R + T
+ mesh_render, is_bkg = render_mesh(vertices, face, cam_param=cam_param,
+ bkg=np.ones((render_shape[0],render_shape[1], 3), dtype=np.float32) * 255,
+ return_bg_mask=True)
+
+ rgb_mesh = mesh_render.astype(np.uint8)
+ t_image = (data[render_key][i].permute(1, 2, 0)*255).numpy().astype(np.uint8)
+
+ blend_ratio = 0.7
+ vis_img = np.concatenate([rgb_mesh, t_image, (blend_ratio * rgb_mesh + (1 - blend_ratio) * t_image).astype(np.uint8)], axis=1)
+ cam_idx = int(data.get('cam_idxs', [i for j in range(16)])[i])
+
+ cv2.imwrite(os.path.join(save_root, f"render_{cam_idx}.jpg"), vis_img[:, :, (2, 1, 0)])
diff --git a/lam/launch.py b/lam/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b90f428d9904a9a4d869c4e5978ec206a6fdf2c4
--- /dev/null
+++ b/lam/launch.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+
+from lam.runners import REGISTRY_RUNNERS
+
+
+def main():
+
+ parser = argparse.ArgumentParser(description='lam launcher')
+ parser.add_argument('runner', type=str, help='Runner to launch')
+ args, unknown = parser.parse_known_args()
+
+ if args.runner not in REGISTRY_RUNNERS:
+ raise ValueError('Runner {} not found'.format(args.runner))
+
+ RunnerClass = REGISTRY_RUNNERS[args.runner]
+ with RunnerClass() as runner:
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/lam/losses/__init__.py b/lam/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed8da8292b9982cddbfaf84ad3ea74b4bfa9925d
--- /dev/null
+++ b/lam/losses/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .pixelwise import *
+from .perceptual import *
+from .tvloss import *
diff --git a/lam/losses/perceptual.py b/lam/losses/perceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b0ff579b8c8f8aa041f3d4074b57b68f5af0c33
--- /dev/null
+++ b/lam/losses/perceptual.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+__all__ = ['LPIPSLoss']
+
+
+class LPIPSLoss(nn.Module):
+ """
+ Compute LPIPS loss between two images.
+ """
+
+ def __init__(self, device, prefech: bool = False):
+ super().__init__()
+ self.device = device
+ self.cached_models = {}
+ if prefech:
+ self.prefetch_models()
+
+ def _get_model(self, model_name: str):
+ if model_name not in self.cached_models:
+ import warnings
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=UserWarning)
+ import lpips
+ _model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device)
+ _model = torch.compile(_model)
+ self.cached_models[model_name] = _model
+ return self.cached_models[model_name]
+
+ def prefetch_models(self):
+ _model_names = ['alex', 'vgg']
+ for model_name in _model_names:
+ self._get_model(model_name)
+
+ def forward(self, x, y, is_training: bool = True, conf_sigma=None, only_sym_conf=False):
+ """
+ Assume images are 0-1 scaled and channel first.
+
+ Args:
+ x: [N, M, C, H, W]
+ y: [N, M, C, H, W]
+ is_training: whether to use VGG or AlexNet.
+
+ Returns:
+ Mean-reduced LPIPS loss across batch.
+ """
+ model_name = 'vgg' if is_training else 'alex'
+ loss_fn = self._get_model(model_name)
+ EPS = 1e-7
+ if len(x.shape) == 5:
+ N, M, C, H, W = x.shape
+ x = x.reshape(N*M, C, H, W)
+ y = y.reshape(N*M, C, H, W)
+ image_loss = loss_fn(x, y, normalize=True)
+ image_loss = image_loss.mean(dim=[1, 2, 3])
+ batch_loss = image_loss.reshape(N, M).mean(dim=1)
+ all_loss = batch_loss.mean()
+ else:
+ image_loss = loss_fn(x, y, normalize=True)
+ if conf_sigma is not None:
+ image_loss = image_loss / (2*conf_sigma**2 +EPS) + (conf_sigma +EPS).log()
+ image_loss = image_loss.mean(dim=[1, 2, 3])
+ all_loss = image_loss.mean()
+ return all_loss
diff --git a/lam/losses/pixelwise.py b/lam/losses/pixelwise.py
new file mode 100644
index 0000000000000000000000000000000000000000..c68d88235b5356f34e3a4baa3b46ce7b0b0a8e11
--- /dev/null
+++ b/lam/losses/pixelwise.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+__all__ = ['PixelLoss']
+
+
+class PixelLoss(nn.Module):
+ """
+ Pixel-wise loss between two images.
+ """
+
+ def __init__(self, option: str = 'mse'):
+ super().__init__()
+ self.loss_fn = self._build_from_option(option)
+
+ @staticmethod
+ def _build_from_option(option: str, reduction: str = 'none'):
+ if option == 'mse':
+ return nn.MSELoss(reduction=reduction)
+ elif option == 'l1':
+ return nn.L1Loss(reduction=reduction)
+ else:
+ raise NotImplementedError(f'Unknown pixel loss option: {option}')
+
+ @torch.compile
+ def forward(self, x, y, conf_sigma=None, only_sym_conf=False):
+ """
+ Assume images are channel first.
+
+ Args:
+ x: [N, M, C, H, W]
+ y: [N, M, C, H, W]
+
+ Returns:
+ Mean-reduced pixel loss across batch.
+ """
+ N, M, C, H, W = x.shape
+ x = rearrange(x, "n m c h w -> (n m) c h w")
+ y = rearrange(y, "n m c h w -> (n m) c h w")
+ image_loss = self.loss_fn(x, y)
+
+ image_loss = image_loss.mean(dim=[1, 2, 3])
+ batch_loss = image_loss.reshape(N, M).mean(dim=1)
+ all_loss = batch_loss.mean()
+ return all_loss
diff --git a/lam/losses/tvloss.py b/lam/losses/tvloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..77a13b69b6f9fcacc38940373bf8159b3cf61459
--- /dev/null
+++ b/lam/losses/tvloss.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+
+__all__ = ['TVLoss']
+
+
+class TVLoss(nn.Module):
+ """
+ Total variance loss.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def numel_excluding_first_dim(self, x):
+ return x.numel() // x.shape[0]
+
+ @torch.compile
+ def forward(self, x):
+ """
+ Assume batched and channel first with inner sizes.
+
+ Args:
+ x: [N, M, C, H, W]
+
+ Returns:
+ Mean-reduced TV loss with element-level scaling.
+ """
+ N, M, C, H, W = x.shape
+ x = x.reshape(N*M, C, H, W)
+ diff_i = x[..., 1:, :] - x[..., :-1, :]
+ diff_j = x[..., :, 1:] - x[..., :, :-1]
+ div_i = self.numel_excluding_first_dim(diff_i)
+ div_j = self.numel_excluding_first_dim(diff_j)
+ tv_i = diff_i.pow(2).sum(dim=[1,2,3]) / div_i
+ tv_j = diff_j.pow(2).sum(dim=[1,2,3]) / div_j
+ tv = tv_i + tv_j
+ batch_tv = tv.reshape(N, M).mean(dim=1)
+ all_tv = batch_tv.mean()
+ return all_tv
diff --git a/lam/models/__init__.py b/lam/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cfbe21f611c508b85284f4b5ddf2340d0b2d5b5
--- /dev/null
+++ b/lam/models/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .modeling_lam import ModelLAM
+
+
+model_dict = {
+ 'lam': ModelLAM,
+}
diff --git a/lam/models/block.py b/lam/models/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..efaf23232362829fac07b2bb30daeca8176f9f9e
--- /dev/null
+++ b/lam/models/block.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch.nn as nn
+
+from .modulate import ModLN
+
+
+class BasicBlock(nn.Module):
+ """
+ Transformer block that is in its simplest form.
+ Designed for PF-LRM architecture.
+ """
+ # Block contains a self-attention layer and an MLP
+ def __init__(self, inner_dim: int, num_heads: int, eps: float,
+ attn_drop: float = 0., attn_bias: bool = False,
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
+ self.self_attn = nn.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads,
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
+ nn.GELU(),
+ nn.Dropout(mlp_drop),
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
+ nn.Dropout(mlp_drop),
+ )
+
+ def forward(self, x):
+ # x: [N, L, D]
+ before_sa = self.norm1(x)
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class ConditionBlock(nn.Module):
+ """
+ Transformer block that takes in a cross-attention condition.
+ Designed for SparseLRM architecture.
+ """
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
+ def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
+ attn_drop: float = 0., attn_bias: bool = False,
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
+ self.self_attn = nn.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads,
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
+ self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
+ nn.GELU(),
+ nn.Dropout(mlp_drop),
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
+ nn.Dropout(mlp_drop),
+ )
+
+ def forward(self, x, cond):
+ # x: [N, L, D]
+ # cond: [N, L_cond, D_cond]
+ x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
+ before_sa = self.norm2(x)
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
+ x = x + self.mlp(self.norm3(x))
+ return x
+
+
+class ConditionModulationBlock(nn.Module):
+ """
+ Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
+ Designed for raw LRM architecture.
+ """
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
+ def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
+ attn_drop: float = 0., attn_bias: bool = False,
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
+ super().__init__()
+ self.norm1 = ModLN(inner_dim, mod_dim, eps)
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
+ self.norm2 = ModLN(inner_dim, mod_dim, eps)
+ self.self_attn = nn.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads,
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
+ self.norm3 = ModLN(inner_dim, mod_dim, eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
+ nn.GELU(),
+ nn.Dropout(mlp_drop),
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
+ nn.Dropout(mlp_drop),
+ )
+
+ def forward(self, x, cond, mod):
+ # x: [N, L, D]
+ # cond: [N, L_cond, D_cond]
+ # mod: [N, D_mod]
+ x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
+ before_sa = self.norm2(x, mod)
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
+ x = x + self.mlp(self.norm3(x, mod))
+ return x
diff --git a/lam/models/discriminator.py b/lam/models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..31412138e21ae6fd689ab494a1036abf88d71662
--- /dev/null
+++ b/lam/models/discriminator.py
@@ -0,0 +1,120 @@
+"""
+Ported from Paella
+"""
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+
+import functools
+# import torch.nn as nn
+from taming.modules.util import ActNorm
+
+
+# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
+class Discriminator(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
+ super().__init__()
+ d = max(depth - 3, 3)
+ layers = [
+ nn.utils.spectral_norm(
+ nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
+ ),
+ nn.LeakyReLU(0.2),
+ ]
+ for i in range(depth - 1):
+ c_in = hidden_channels // (2 ** max((d - i), 0))
+ c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
+ layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+ layers.append(nn.InstanceNorm2d(c_out))
+ layers.append(nn.LeakyReLU(0.2))
+ self.encoder = nn.Sequential(*layers)
+ self.shuffle = nn.Conv2d(
+ (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
+ )
+ # self.logits = nn.Sigmoid()
+
+
+ def forward(self, x, cond=None):
+ x = self.encoder(x)
+ if cond is not None:
+ cond = cond.view(
+ cond.size(0),
+ cond.size(1),
+ 1,
+ 1,
+ ).expand(-1, -1, x.size(-2), x.size(-1))
+ x = torch.cat([x, cond], dim=1)
+ x = self.shuffle(x)
+ # x = self.logits(x)
+ return x
+
+
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find('BatchNorm') != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ # norm_layer = nn.BatchNorm2d
+ norm_layer = nn.InstanceNorm2d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ # use_bias = norm_layer.func != nn.BatchNorm2d
+ use_bias = norm_layer.func != nn.InstanceNorm2d
+ else:
+ # use_bias = norm_layer != nn.BatchNorm2d
+ use_bias = norm_layer != nn.InstanceNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, False)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, False)
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
diff --git a/lam/models/encoders/__init__.py b/lam/models/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/lam/models/encoders/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/lam/models/encoders/dino_wrapper.py b/lam/models/encoders/dino_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb82225eb23c3c9b362a4f962c40addd18fbe5fc
--- /dev/null
+++ b/lam/models/encoders/dino_wrapper.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from transformers import ViTImageProcessor, ViTModel
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class DinoWrapper(nn.Module):
+ """
+ Dino v1 wrapper using huggingface transformer implementation.
+ """
+ def __init__(self, model_name: str, freeze: bool = True, encoder_feat_dim: int = 384):
+ super().__init__()
+ self.model, self.processor = self._build_dino(model_name)
+ if freeze:
+ self._freeze()
+
+ @torch.compile
+ def forward_model(self, inputs):
+ return self.model(**inputs, interpolate_pos_encoding=True)
+
+ def forward(self, image):
+ # image: [N, C, H, W], on cpu
+ # RGB image with [0,1] scale and properly sized
+ inputs = self.processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
+ # This resampling of positional embedding uses bicubic interpolation
+ outputs = self.forward_model(inputs)
+ last_hidden_states = outputs.last_hidden_state
+ return last_hidden_states
+
+ def _freeze(self):
+ logger.warning(f"======== Freezing DinoWrapper ========")
+ self.model.eval()
+ for name, param in self.model.named_parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
+ import requests
+ try:
+ model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
+ processor = ViTImageProcessor.from_pretrained(model_name)
+ return model, processor
+ except requests.exceptions.ProxyError as err:
+ if proxy_error_retries > 0:
+ print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...")
+ import time
+ time.sleep(proxy_error_cooldown)
+ return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
+ else:
+ raise err
diff --git a/lam/models/encoders/dinov2/__init__.py b/lam/models/encoders/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/lam/models/encoders/dinov2/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/lam/models/encoders/dinov2/hub/__init__.py b/lam/models/encoders/dinov2/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/lam/models/encoders/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/lam/models/encoders/dinov2/hub/backbones.py b/lam/models/encoders/dinov2/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fd8c4010204da1f1e413db66d24a87e2a39a358
--- /dev/null
+++ b/lam/models/encoders/dinov2/hub/backbones.py
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ # ********** Modified by Zexin He in 2023-2024 **********
+ state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern
+ if vit_kwargs.get("modulation_dim") is not None:
+ state_dict = {
+ k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v
+ for k, v in state_dict.items()
+ }
+ model.load_state_dict(state_dict, strict=False)
+ else:
+ model.load_state_dict(state_dict, strict=True)
+ # ********************************************************
+
+ return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
diff --git a/lam/models/encoders/dinov2/hub/classifiers.py b/lam/models/encoders/dinov2/hub/classifiers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f0841efa80ab3d564cd320d61da254af182606b
--- /dev/null
+++ b/lam/models/encoders/dinov2/hub/classifiers.py
@@ -0,0 +1,268 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from .backbones import _make_dinov2_model
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ IMAGENET1K = "IMAGENET1K"
+
+
+def _make_dinov2_linear_classification_head(
+ *,
+ arch_name: str = "vit_large",
+ patch_size: int = 14,
+ embed_dim: int = 1024,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ num_register_tokens: int = 0,
+ **kwargs,
+):
+ if layers not in (1, 4):
+ raise AssertionError(f"Unsupported number of layers: {layers}")
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
+
+ if pretrained:
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ layers_str = str(layers) if layers == 4 else ""
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ linear_head.load_state_dict(state_dict, strict=True)
+
+ return linear_head
+
+
+class _LinearClassifierWrapper(nn.Module):
+ def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
+ super().__init__()
+ self.backbone = backbone
+ self.linear_head = linear_head
+ self.layers = layers
+
+ def forward(self, x):
+ if self.layers == 1:
+ x = self.backbone.forward_features(x)
+ cls_token = x["x_norm_clstoken"]
+ patch_tokens = x["x_norm_patchtokens"]
+ # fmt: off
+ linear_input = torch.cat([
+ cls_token,
+ patch_tokens.mean(dim=1),
+ ], dim=1)
+ # fmt: on
+ elif self.layers == 4:
+ x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
+ # fmt: off
+ linear_input = torch.cat([
+ x[0][1],
+ x[1][1],
+ x[2][1],
+ x[3][1],
+ x[3][0].mean(dim=1),
+ ], dim=1)
+ # fmt: on
+ else:
+ assert False, f"Unsupported number of layers: {self.layers}"
+ return self.linear_head(linear_input)
+
+
+def _make_dinov2_linear_classifier(
+ *,
+ arch_name: str = "vit_large",
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ **kwargs,
+):
+ backbone = _make_dinov2_model(
+ arch_name=arch_name,
+ pretrained=pretrained,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ **kwargs,
+ )
+
+ embed_dim = backbone.embed_dim
+ patch_size = backbone.patch_size
+ linear_head = _make_dinov2_linear_classification_head(
+ arch_name=arch_name,
+ patch_size=patch_size,
+ embed_dim=embed_dim,
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=num_register_tokens,
+ )
+
+ return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
+
+
+def dinov2_vits14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_small",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_base",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_large",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_giant2",
+ layers=layers,
+ ffn_layer="swiglufused",
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_small",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_base",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_large",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_giant2",
+ layers=layers,
+ ffn_layer="swiglufused",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
diff --git a/lam/models/encoders/dinov2/hub/depth/__init__.py b/lam/models/encoders/dinov2/hub/depth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..91716e58ab6158d814df8c653644d9af4c7be65c
--- /dev/null
+++ b/lam/models/encoders/dinov2/hub/depth/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .decode_heads import BNHead, DPTHead
+from .encoder_decoder import DepthEncoderDecoder
diff --git a/lam/models/encoders/dinov2/hub/depth/decode_heads.py b/lam/models/encoders/dinov2/hub/depth/decode_heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..f455accad38fec6ecdd53460233a564c34f434da
--- /dev/null
+++ b/lam/models/encoders/dinov2/hub/depth/decode_heads.py
@@ -0,0 +1,747 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import copy
+from functools import partial
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+
+from .ops import resize
+
+
+# XXX: (Untested) replacement for mmcv.imdenormalize()
+def _imdenormalize(img, mean, std, to_bgr=True):
+ import numpy as np
+
+ mean = mean.reshape(1, -1).astype(np.float64)
+ std = std.reshape(1, -1).astype(np.float64)
+ img = (img * std) + mean
+ if to_bgr:
+ img = img[::-1]
+ return img
+
+
+class DepthBaseDecodeHead(nn.Module):
+ """Base class for BaseDecodeHead.
+
+ Args:
+ in_channels (List): Input channels.
+ channels (int): Channels after modules, before conv_depth.
+ conv_layer (nn.Module): Conv layers. Default: None.
+ act_layer (nn.Module): Activation layers. Default: nn.ReLU.
+ loss_decode (dict): Config of decode loss.
+ Default: ().
+ sampler (dict|None): The config of depth map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ min_depth (int): Min depth in dataset setting.
+ Default: 1e-3.
+ max_depth (int): Max depth in dataset setting.
+ Default: None.
+ norm_layer (dict|None): Norm layers.
+ Default: None.
+ classify (bool): Whether predict depth in a cls.-reg. manner.
+ Default: False.
+ n_bins (int): The number of bins used in cls. step.
+ Default: 256.
+ bins_strategy (str): The discrete strategy used in cls. step.
+ Default: 'UD'.
+ norm_strategy (str): The norm strategy on cls. probability
+ distribution. Default: 'linear'
+ scale_up (str): Whether predict depth in a scale-up manner.
+ Default: False.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ conv_layer=None,
+ act_layer=nn.ReLU,
+ channels=96,
+ loss_decode=(),
+ sampler=None,
+ align_corners=False,
+ min_depth=1e-3,
+ max_depth=None,
+ norm_layer=None,
+ classify=False,
+ n_bins=256,
+ bins_strategy="UD",
+ norm_strategy="linear",
+ scale_up=False,
+ ):
+ super(DepthBaseDecodeHead, self).__init__()
+
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conf_layer = conv_layer
+ self.act_layer = act_layer
+ self.loss_decode = loss_decode
+ self.align_corners = align_corners
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.norm_layer = norm_layer
+ self.classify = classify
+ self.n_bins = n_bins
+ self.scale_up = scale_up
+
+ if self.classify:
+ assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
+ assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
+
+ self.bins_strategy = bins_strategy
+ self.norm_strategy = norm_strategy
+ self.softmax = nn.Softmax(dim=1)
+ self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
+ else:
+ self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
+
+ self.relu = nn.ReLU()
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, inputs, img_metas):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, img, inputs, img_metas, depth_gt):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ depth_gt (Tensor): GT depth
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ depth_pred = self.forward(inputs, img_metas)
+ losses = self.losses(depth_pred, depth_gt)
+
+ log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
+ losses.update(**log_imgs)
+
+ return losses
+
+ def forward_test(self, inputs, img_metas):
+ """Forward function for testing.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+
+ Returns:
+ Tensor: Output depth map.
+ """
+ return self.forward(inputs, img_metas)
+
+ def depth_pred(self, feat):
+ """Prediction each pixel."""
+ if self.classify:
+ logit = self.conv_depth(feat)
+
+ if self.bins_strategy == "UD":
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
+ elif self.bins_strategy == "SID":
+ bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
+
+ # following Adabins, default linear
+ if self.norm_strategy == "linear":
+ logit = torch.relu(logit)
+ eps = 0.1
+ logit = logit + eps
+ logit = logit / logit.sum(dim=1, keepdim=True)
+ elif self.norm_strategy == "softmax":
+ logit = torch.softmax(logit, dim=1)
+ elif self.norm_strategy == "sigmoid":
+ logit = torch.sigmoid(logit)
+ logit = logit / logit.sum(dim=1, keepdim=True)
+
+ output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
+
+ else:
+ if self.scale_up:
+ output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
+ else:
+ output = self.relu(self.conv_depth(feat)) + self.min_depth
+ return output
+
+ def losses(self, depth_pred, depth_gt):
+ """Compute depth loss."""
+ loss = dict()
+ depth_pred = resize(
+ input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
+ )
+ if not isinstance(self.loss_decode, nn.ModuleList):
+ losses_decode = [self.loss_decode]
+ else:
+ losses_decode = self.loss_decode
+ for loss_decode in losses_decode:
+ if loss_decode.loss_name not in loss:
+ loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
+ else:
+ loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
+ return loss
+
+ def log_images(self, img_path, depth_pred, depth_gt, img_meta):
+ import numpy as np
+
+ show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
+ show_img = show_img.numpy().astype(np.float32)
+ show_img = _imdenormalize(
+ show_img,
+ img_meta["img_norm_cfg"]["mean"],
+ img_meta["img_norm_cfg"]["std"],
+ img_meta["img_norm_cfg"]["to_rgb"],
+ )
+ show_img = np.clip(show_img, 0, 255)
+ show_img = show_img.astype(np.uint8)
+ show_img = show_img[:, :, ::-1]
+ show_img = show_img.transpose(0, 2, 1)
+ show_img = show_img.transpose(1, 0, 2)
+
+ depth_pred = depth_pred / torch.max(depth_pred)
+ depth_gt = depth_gt / torch.max(depth_gt)
+
+ depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
+ depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
+
+ return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
+
+
+class BNHead(DepthBaseDecodeHead):
+ """Just a batchnorm."""
+
+ def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
+ super().__init__(**kwargs)
+ self.input_transform = input_transform
+ self.in_index = in_index
+ self.upsample = upsample
+ # self.bn = nn.SyncBatchNorm(self.in_channels)
+ if self.classify:
+ self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
+ else:
+ self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if "concat" in self.input_transform:
+ inputs = [inputs[i] for i in self.in_index]
+ if "resize" in self.input_transform:
+ inputs = [
+ resize(
+ input=x,
+ size=[s * self.upsample for s in inputs[0].shape[2:]],
+ mode="bilinear",
+ align_corners=self.align_corners,
+ )
+ for x in inputs
+ ]
+ inputs = torch.cat(inputs, dim=1)
+ elif self.input_transform == "multiple_select":
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ def _forward_feature(self, inputs, img_metas=None, **kwargs):
+ """Forward function for feature maps before classifying each pixel with
+ ``self.cls_seg`` fc.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ Returns:
+ feats (Tensor): A tensor of shape (batch_size, self.channels,
+ H, W) which is feature map for last layer of decoder head.
+ """
+ # accept lists (for cls token)
+ inputs = list(inputs)
+ for i, x in enumerate(inputs):
+ if len(x) == 2:
+ x, cls_token = x[0], x[1]
+ if len(x.shape) == 2:
+ x = x[:, :, None, None]
+ cls_token = cls_token[:, :, None, None].expand_as(x)
+ inputs[i] = torch.cat((x, cls_token), 1)
+ else:
+ x = x[0]
+ if len(x.shape) == 2:
+ x = x[:, :, None, None]
+ inputs[i] = x
+ x = self._transform_inputs(inputs)
+ # feats = self.bn(x)
+ return x
+
+ def forward(self, inputs, img_metas=None, **kwargs):
+ """Forward function."""
+ output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
+ output = self.depth_pred(output)
+ return output
+
+
+class ConvModule(nn.Module):
+ """A conv block that bundles conv/norm/activation layers.
+
+ This block simplifies the usage of convolution layers, which are commonly
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+ It is based upon three build methods: `build_conv_layer()`,
+ `build_norm_layer()` and `build_activation_layer()`.
+
+ Besides, we add some additional features in this module.
+ 1. Automatically set `bias` of the conv layer.
+ 2. Spectral norm is supported.
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+ supports zero and circular padding, and we add "reflect" padding mode.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Same as that in ``nn._ConvNd``.
+ bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
+ False. Default: "auto".
+ conv_layer (nn.Module): Convolution layer. Default: None,
+ which means using conv2d.
+ norm_layer (nn.Module): Normalization layer. Default: None.
+ act_layer (nn.Module): Activation layer. Default: nn.ReLU.
+ inplace (bool): Whether to use inplace mode for activation.
+ Default: True.
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
+ Default: False.
+ padding_mode (str): If the `padding_mode` has not been supported by
+ current `Conv2d` in PyTorch, we will use our own padding layer
+ instead. Currently, we support ['zeros', 'circular'] with official
+ implementation and ['reflect'] with our own implementation.
+ Default: 'zeros'.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Common examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ Default: ('conv', 'norm', 'act').
+ """
+
+ _abbr_ = "conv_block"
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias="auto",
+ conv_layer=nn.Conv2d,
+ norm_layer=None,
+ act_layer=nn.ReLU,
+ inplace=True,
+ with_spectral_norm=False,
+ padding_mode="zeros",
+ order=("conv", "norm", "act"),
+ ):
+ super(ConvModule, self).__init__()
+ official_padding_mode = ["zeros", "circular"]
+ self.conv_layer = conv_layer
+ self.norm_layer = norm_layer
+ self.act_layer = act_layer
+ self.inplace = inplace
+ self.with_spectral_norm = with_spectral_norm
+ self.with_explicit_padding = padding_mode not in official_padding_mode
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == set(["conv", "norm", "act"])
+
+ self.with_norm = norm_layer is not None
+ self.with_activation = act_layer is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == "auto":
+ bias = not self.with_norm
+ self.with_bias = bias
+
+ if self.with_explicit_padding:
+ if padding_mode == "zeros":
+ padding_layer = nn.ZeroPad2d
+ else:
+ raise AssertionError(f"Unsupported padding mode: {padding_mode}")
+ self.pad = padding_layer(padding)
+
+ # reset padding to 0 for conv module
+ conv_padding = 0 if self.with_explicit_padding else padding
+ # build convolution layer
+ self.conv = self.conv_layer(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=conv_padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ if self.with_spectral_norm:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index("norm") > order.index("conv"):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+ norm = partial(norm_layer, num_features=norm_channels)
+ self.add_module("norm", norm)
+ if self.with_bias:
+ from torch.nnModules.batchnorm import _BatchNorm
+ from torch.nnModules.instancenorm import _InstanceNorm
+
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+ warnings.warn("Unnecessary conv bias before batch/instance norm")
+ else:
+ self.norm_name = None
+
+ # build activation layer
+ if self.with_activation:
+ # nn.Tanh has no 'inplace' argument
+ # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU)
+ if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)):
+ act_layer = partial(act_layer, inplace=inplace)
+ self.activate = act_layer()
+
+ # Use msra init by default
+ self.init_weights()
+
+ @property
+ def norm(self):
+ if self.norm_name:
+ return getattr(self, self.norm_name)
+ else:
+ return None
+
+ def init_weights(self):
+ # 1. It is mainly for customized conv layers with their own
+ # initialization manners by calling their own ``init_weights()``,
+ # and we do not want ConvModule to override the initialization.
+ # 2. For customized conv layers without their own initialization
+ # manners (that is, they don't have their own ``init_weights()``)
+ # and PyTorch's conv layers, they will be initialized by
+ # this method with default ``kaiming_init``.
+ # Note: For PyTorch's conv layers, they will be overwritten by our
+ # initialization implementation using default ``kaiming_init``.
+ if not hasattr(self.conv, "init_weights"):
+ if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU):
+ nonlinearity = "leaky_relu"
+ a = 0.01 # XXX: default negative_slope
+ else:
+ nonlinearity = "relu"
+ a = 0
+ if hasattr(self.conv, "weight") and self.conv.weight is not None:
+ nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
+ if hasattr(self.conv, "bias") and self.conv.bias is not None:
+ nn.init.constant_(self.conv.bias, 0)
+ if self.with_norm:
+ if hasattr(self.norm, "weight") and self.norm.weight is not None:
+ nn.init.constant_(self.norm.weight, 1)
+ if hasattr(self.norm, "bias") and self.norm.bias is not None:
+ nn.init.constant_(self.norm.bias, 0)
+
+ def forward(self, x, activate=True, norm=True):
+ for layer in self.order:
+ if layer == "conv":
+ if self.with_explicit_padding:
+ x = self.pad(x)
+ x = self.conv(x)
+ elif layer == "norm" and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == "act" and activate and self.with_activation:
+ x = self.activate(x)
+ return x
+
+
+class Interpolate(nn.Module):
+ def __init__(self, scale_factor, mode, align_corners=False):
+ super(Interpolate, self).__init__()
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+ return x
+
+
+class HeadDepth(nn.Module):
+ def __init__(self, features):
+ super(HeadDepth, self).__init__()
+ self.head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(self, x):
+ x = self.head(x)
+ return x
+
+
+class ReassembleBlocks(nn.Module):
+ """ViTPostProcessBlock, process cls_token in ViT backbone output and
+ rearrange the feature vector to feature map.
+ Args:
+ in_channels (int): ViT feature channels. Default: 768.
+ out_channels (List): output channels of each stage.
+ Default: [96, 192, 384, 768].
+ readout_type (str): Type of readout operation. Default: 'ignore'.
+ patch_size (int): The patch size. Default: 16.
+ """
+
+ def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
+ super(ReassembleBlocks, self).__init__()
+
+ assert readout_type in ["ignore", "add", "project"]
+ self.readout_type = readout_type
+ self.patch_size = patch_size
+
+ self.projects = nn.ModuleList(
+ [
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ act_layer=None,
+ )
+ for out_channel in out_channels
+ ]
+ )
+
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+ if self.readout_type == "project":
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
+
+ def forward(self, inputs):
+ assert isinstance(inputs, list)
+ out = []
+ for i, x in enumerate(inputs):
+ assert len(x) == 2
+ x, cls_token = x[0], x[1]
+ feature_shape = x.shape
+ if self.readout_type == "project":
+ x = x.flatten(2).permute((0, 2, 1))
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ x = x.permute(0, 2, 1).reshape(feature_shape)
+ elif self.readout_type == "add":
+ x = x.flatten(2) + cls_token.unsqueeze(-1)
+ x = x.reshape(feature_shape)
+ else:
+ pass
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+ out.append(x)
+ return out
+
+
+class PreActResidualConvUnit(nn.Module):
+ """ResidualConvUnit, pre-activate residual unit.
+ Args:
+ in_channels (int): number of channels in the input feature map.
+ act_layer (nn.Module): activation layer.
+ norm_layer (nn.Module): norm layer.
+ stride (int): stride of the first block. Default: 1
+ dilation (int): dilation rate for convs layers. Default: 1.
+ """
+
+ def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
+ super(PreActResidualConvUnit, self).__init__()
+
+ self.conv1 = ConvModule(
+ in_channels,
+ in_channels,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ bias=False,
+ order=("act", "conv", "norm"),
+ )
+
+ self.conv2 = ConvModule(
+ in_channels,
+ in_channels,
+ 3,
+ padding=1,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ bias=False,
+ order=("act", "conv", "norm"),
+ )
+
+ def forward(self, inputs):
+ inputs_ = inputs.clone()
+ x = self.conv1(inputs)
+ x = self.conv2(x)
+ return x + inputs_
+
+
+class FeatureFusionBlock(nn.Module):
+ """FeatureFusionBlock, merge feature map from different stages.
+ Args:
+ in_channels (int): Input channels.
+ act_layer (nn.Module): activation layer for ResidualConvUnit.
+ norm_layer (nn.Module): normalization layer.
+ expand (bool): Whether expand the channels in post process block.
+ Default: False.
+ align_corners (bool): align_corner setting for bilinear upsample.
+ Default: True.
+ """
+
+ def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
+ super(FeatureFusionBlock, self).__init__()
+
+ self.in_channels = in_channels
+ self.expand = expand
+ self.align_corners = align_corners
+
+ self.out_channels = in_channels
+ if self.expand:
+ self.out_channels = in_channels // 2
+
+ self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)
+
+ self.res_conv_unit1 = PreActResidualConvUnit(
+ in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
+ )
+ self.res_conv_unit2 = PreActResidualConvUnit(
+ in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
+ )
+
+ def forward(self, *inputs):
+ x = inputs[0]
+ if len(inputs) == 2:
+ if x.shape != inputs[1].shape:
+ res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
+ else:
+ res = inputs[1]
+ x = x + self.res_conv_unit1(res)
+ x = self.res_conv_unit2(x)
+ x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
+ x = self.project(x)
+ return x
+
+
+class DPTHead(DepthBaseDecodeHead):
+ """Vision Transformers for Dense Prediction.
+ This head is implemented of `DPT `_.
+ Args:
+ embed_dims (int): The embed dimension of the ViT backbone.
+ Default: 768.
+ post_process_channels (List): Out channels of post process conv
+ layers. Default: [96, 192, 384, 768].
+ readout_type (str): Type of readout operation. Default: 'ignore'.
+ patch_size (int): The patch size. Default: 16.
+ expand_channels (bool): Whether expand the channels in post process
+ block. Default: False.
+ """
+
+ def __init__(
+ self,
+ embed_dims=768,
+ post_process_channels=[96, 192, 384, 768],
+ readout_type="ignore",
+ patch_size=16,
+ expand_channels=False,
+ **kwargs,
+ ):
+ super(DPTHead, self).__init__(**kwargs)
+
+ self.in_channels = self.in_channels
+ self.expand_channels = expand_channels
+ self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
+
+ self.post_process_channels = [
+ channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
+ ]
+ self.convs = nn.ModuleList()
+ for channel in self.post_process_channels:
+ self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
+ self.fusion_blocks = nn.ModuleList()
+ for _ in range(len(self.convs)):
+ self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
+ self.fusion_blocks[0].res_conv_unit1 = None
+ self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
+ self.num_fusion_blocks = len(self.fusion_blocks)
+ self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
+ self.num_post_process_channels = len(self.post_process_channels)
+ assert self.num_fusion_blocks == self.num_reassemble_blocks
+ assert self.num_reassemble_blocks == self.num_post_process_channels
+ self.conv_depth = HeadDepth(self.channels)
+
+ def forward(self, inputs, img_metas):
+ assert len(inputs) == self.num_reassemble_blocks
+ x = [inp for inp in inputs]
+ x = self.reassemble_blocks(x)
+ x = [self.convs[i](feature) for i, feature in enumerate(x)]
+ out = self.fusion_blocks[0](x[-1])
+ for i in range(1, len(self.fusion_blocks)):
+ out = self.fusion_blocks[i](out, x[-(i + 1)])
+ out = self.project(out)
+ out = self.depth_pred(out)
+ return out
diff --git a/lam/models/encoders/dinov2/hub/depth/encoder_decoder.py b/lam/models/encoders/dinov2/hub/depth/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb29ced67957a336e763b0e7c90c0eeaea36fea8
--- /dev/null
+++ b/lam/models/encoders/dinov2/hub/depth/encoder_decoder.py
@@ -0,0 +1,351 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .ops import resize
+
+
+def add_prefix(inputs, prefix):
+ """Add prefix for dict.
+
+ Args:
+ inputs (dict): The input dict with str keys.
+ prefix (str): The prefix to add.
+
+ Returns:
+
+ dict: The dict with keys updated with ``prefix``.
+ """
+
+ outputs = dict()
+ for name, value in inputs.items():
+ outputs[f"{prefix}.{name}"] = value
+
+ return outputs
+
+
+class DepthEncoderDecoder(nn.Module):
+ """Encoder Decoder depther.
+
+ EncoderDecoder typically consists of backbone and decode_head.
+ """
+
+ def __init__(self, backbone, decode_head):
+ super(DepthEncoderDecoder, self).__init__()
+
+ self.backbone = backbone
+ self.decode_head = decode_head
+ self.align_corners = self.decode_head.align_corners
+
+ def extract_feat(self, img):
+ """Extract features from images."""
+ return self.backbone(img)
+
+ def encode_decode(self, img, img_metas, rescale=True, size=None):
+ """Encode images with backbone and decode into a depth estimation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ out = self._decode_head_forward_test(x, img_metas)
+ # crop the pred depth to the certain range.
+ out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
+ if rescale:
+ if size is None:
+ if img_metas is not None:
+ size = img_metas[0]["ori_shape"][:2]
+ else:
+ size = img.shape[2:]
+ out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
+ return out
+
+ def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+ loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
+ losses.update(add_prefix(loss_decode, "decode"))
+ return losses
+
+ def _decode_head_forward_test(self, x, img_metas):
+ """Run forward function and calculate loss for decode head in
+ inference."""
+ depth_pred = self.decode_head.forward_test(x, img_metas)
+ return depth_pred
+
+ def forward_dummy(self, img):
+ """Dummy forward function."""
+ depth = self.encode_decode(img, None)
+
+ return depth
+
+ def forward_train(self, img, img_metas, depth_gt, **kwargs):
+ """Forward function for training.
+
+ Args:
+ img (Tensor): Input images.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ depth_gt (Tensor): Depth gt
+ used if the architecture supports depth estimation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ x = self.extract_feat(img)
+
+ losses = dict()
+
+ # the last of x saves the info from neck
+ loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
+
+ losses.update(loss_decode)
+
+ return losses
+
+ def whole_inference(self, img, img_meta, rescale, size=None):
+ """Inference with full image."""
+ return self.encode_decode(img, img_meta, rescale, size=size)
+
+ def slide_inference(self, img, img_meta, rescale, stride, crop_size):
+ """Inference by sliding-window with overlap.
+
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+ """
+
+ h_stride, w_stride = stride
+ h_crop, w_crop = crop_size
+ batch_size, _, h_img, w_img = img.size()
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ preds = img.new_zeros((batch_size, 1, h_img, w_img))
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = img[:, :, y1:y2, x1:x2]
+ depth_pred = self.encode_decode(crop_img, img_meta, rescale)
+ preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
+
+ count_mat[:, :, y1:y2, x1:x2] += 1
+ assert (count_mat == 0).sum() == 0
+ if torch.onnx.is_in_onnx_export():
+ # cast count_mat to constant while exporting to ONNX
+ count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
+ preds = preds / count_mat
+ return preds
+
+ def inference(self, img, img_meta, rescale, size=None, mode="whole"):
+ """Inference with slide/whole style.
+
+ Args:
+ img (Tensor): The input image of shape (N, 3, H, W).
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
+ 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ rescale (bool): Whether rescale back to original shape.
+
+ Returns:
+ Tensor: The output depth map.
+ """
+
+ assert mode in ["slide", "whole"]
+ ori_shape = img_meta[0]["ori_shape"]
+ assert all(_["ori_shape"] == ori_shape for _ in img_meta)
+ if mode == "slide":
+ depth_pred = self.slide_inference(img, img_meta, rescale)
+ else:
+ depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
+ output = depth_pred
+ flip = img_meta[0]["flip"]
+ if flip:
+ flip_direction = img_meta[0]["flip_direction"]
+ assert flip_direction in ["horizontal", "vertical"]
+ if flip_direction == "horizontal":
+ output = output.flip(dims=(3,))
+ elif flip_direction == "vertical":
+ output = output.flip(dims=(2,))
+
+ return output
+
+ def simple_test(self, img, img_meta, rescale=True):
+ """Simple test with single image."""
+ depth_pred = self.inference(img, img_meta, rescale)
+ if torch.onnx.is_in_onnx_export():
+ # our inference backend only support 4D output
+ depth_pred = depth_pred.unsqueeze(0)
+ return depth_pred
+ depth_pred = depth_pred.cpu().numpy()
+ # unravel batch dim
+ depth_pred = list(depth_pred)
+ return depth_pred
+
+ def aug_test(self, imgs, img_metas, rescale=True):
+ """Test with augmentations.
+
+ Only rescale=True is supported.
+ """
+ # aug_test rescale all imgs back to ori_shape for now
+ assert rescale
+ # to save memory, we get augmented depth logit inplace
+ depth_pred = self.inference(imgs[0], img_metas[0], rescale)
+ for i in range(1, len(imgs)):
+ cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
+ depth_pred += cur_depth_pred
+ depth_pred /= len(imgs)
+ depth_pred = depth_pred.cpu().numpy()
+ # unravel batch dim
+ depth_pred = list(depth_pred)
+ return depth_pred
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
+ if not isinstance(var, list):
+ raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
+ # all images in the same aug batch all of the same ori_shape and pad
+ # shape
+ for img_meta in img_metas:
+ ori_shapes = [_["ori_shape"] for _ in img_meta]
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
+ img_shapes = [_["img_shape"] for _ in img_meta]
+ assert all(shape == img_shapes[0] for shape in img_shapes)
+ pad_shapes = [_["pad_shape"] for _ in img_meta]
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
+
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+ and List[dict]), and when ``resturn_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+
+ Args:
+ data (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+ ``num_samples``.
+ ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ ``log_vars`` contains all the variables to be sent to the
+ logger.
+ ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
+ """
+ losses = self(**data_batch)
+
+ # split losses and images
+ real_losses = {}
+ log_imgs = {}
+ for k, v in losses.items():
+ if "img" in k:
+ log_imgs[k] = v
+ else:
+ real_losses[k] = v
+
+ loss, log_vars = self._parse_losses(real_losses)
+
+ outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
+
+ return outputs
+
+ def val_step(self, data_batch, **kwargs):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ output = self(**data_batch, **kwargs)
+ return output
+
+ @staticmethod
+ def _parse_losses(losses):
+ import torch.distributed as dist
+
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw output of the network, which usually contain
+ losses and other necessary information.
+
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+ which may be a weighted sum of all losses, log_vars contains
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(f"{loss_name} is not a tensor or list of tensors")
+
+ loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
+
+ log_vars["loss"] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
diff --git a/lam/models/encoders/dinov2/hub/depth/ops.py b/lam/models/encoders/dinov2/hub/depth/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e
--- /dev/null
+++ b/lam/models/encoders/dinov2/hub/depth/ops.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import warnings
+
+import torch.nn.functional as F
+
+
+def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
+ if warning:
+ if size is not None and align_corners:
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
+ output_h, output_w = tuple(int(x) for x in size)
+ if output_h > input_h or output_w > output_h:
+ if (
+ (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
+ and (output_h - 1) % (input_h - 1)
+ and (output_w - 1) % (input_w - 1)
+ ):
+ warnings.warn(
+ f"When align_corners={align_corners}, "
+ "the output would more aligned if "
+ f"input size {(input_h, input_w)} is `x+1` and "
+ f"out size {(output_h, output_w)} is `nx+1`"
+ )
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
diff --git a/lam/models/encoders/dinov2/hub/depthers.py b/lam/models/encoders/dinov2/hub/depthers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f88b7e9a41056594e3b3e66107feee98bffab820
--- /dev/null
+++ b/lam/models/encoders/dinov2/hub/depthers.py
@@ -0,0 +1,246 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import torch
+
+from .backbones import _make_dinov2_model
+from .depth import BNHead, DepthEncoderDecoder, DPTHead
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
+
+
+class Weights(Enum):
+ NYU = "NYU"
+ KITTI = "KITTI"
+
+
+def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
+ if not pretrained: # Default
+ return (0.001, 10.0)
+
+ # Pretrained, set according to the training dataset for the provided weights
+ if weights == Weights.KITTI:
+ return (0.001, 80.0)
+
+ if weights == Weights.NYU:
+ return (0.001, 10.0)
+
+ return (0.001, 10.0)
+
+
+def _make_dinov2_linear_depth_head(
+ *,
+ embed_dim: int,
+ layers: int,
+ min_depth: float,
+ max_depth: float,
+ **kwargs,
+):
+ if layers not in (1, 4):
+ raise AssertionError(f"Unsupported number of layers: {layers}")
+
+ if layers == 1:
+ in_index = [0]
+ else:
+ assert layers == 4
+ in_index = [0, 1, 2, 3]
+
+ return BNHead(
+ classify=True,
+ n_bins=256,
+ bins_strategy="UD",
+ norm_strategy="linear",
+ upsample=4,
+ in_channels=[embed_dim] * len(in_index),
+ in_index=in_index,
+ input_transform="resize_concat",
+ channels=embed_dim * len(in_index) * 2,
+ align_corners=False,
+ min_depth=0.001,
+ max_depth=80,
+ loss_decode=(),
+ )
+
+
+def _make_dinov2_linear_depther(
+ *,
+ arch_name: str = "vit_large",
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.NYU,
+ depth_range: Optional[Tuple[float, float]] = None,
+ **kwargs,
+):
+ if layers not in (1, 4):
+ raise AssertionError(f"Unsupported number of layers: {layers}")
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ if depth_range is None:
+ depth_range = _get_depth_range(pretrained, weights)
+ min_depth, max_depth = depth_range
+
+ backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
+
+ embed_dim = backbone.embed_dim
+ patch_size = backbone.patch_size
+ model_name = _make_dinov2_model_name(arch_name, patch_size)
+ linear_depth_head = _make_dinov2_linear_depth_head(
+ embed_dim=embed_dim,
+ layers=layers,
+ min_depth=min_depth,
+ max_depth=max_depth,
+ )
+
+ layer_count = {
+ "vit_small": 12,
+ "vit_base": 12,
+ "vit_large": 24,
+ "vit_giant2": 40,
+ }[arch_name]
+
+ if layers == 4:
+ out_index = {
+ "vit_small": [2, 5, 8, 11],
+ "vit_base": [2, 5, 8, 11],
+ "vit_large": [4, 11, 17, 23],
+ "vit_giant2": [9, 19, 29, 39],
+ }[arch_name]
+ else:
+ assert layers == 1
+ out_index = [layer_count - 1]
+
+ model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
+ model.backbone.forward = partial(
+ backbone.get_intermediate_layers,
+ n=out_index,
+ reshape=True,
+ return_class_token=True,
+ norm=False,
+ )
+ model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))
+
+ if pretrained:
+ layers_str = str(layers) if layers == 4 else ""
+ weights_str = weights.value.lower()
+ url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
+ checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_linear_depther(
+ arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_linear_depther(
+ arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_linear_depther(
+ arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_linear_depther(
+ arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
+ return DPTHead(
+ in_channels=[embed_dim] * 4,
+ channels=256,
+ embed_dims=embed_dim,
+ post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)],
+ readout_type="project",
+ min_depth=min_depth,
+ max_depth=max_depth,
+ loss_decode=(),
+ )
+
+
+def _make_dinov2_dpt_depther(
+ *,
+ arch_name: str = "vit_large",
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.NYU,
+ depth_range: Optional[Tuple[float, float]] = None,
+ **kwargs,
+):
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ if depth_range is None:
+ depth_range = _get_depth_range(pretrained, weights)
+ min_depth, max_depth = depth_range
+
+ backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
+
+ model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
+ dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)
+
+ out_index = {
+ "vit_small": [2, 5, 8, 11],
+ "vit_base": [2, 5, 8, 11],
+ "vit_large": [4, 11, 17, 23],
+ "vit_giant2": [9, 19, 29, 39],
+ }[arch_name]
+
+ model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
+ model.backbone.forward = partial(
+ backbone.get_intermediate_layers,
+ n=out_index,
+ reshape=True,
+ return_class_token=True,
+ norm=False,
+ )
+ model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))
+
+ if pretrained:
+ weights_str = weights.value.lower()
+ url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
+ checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_dpt_depther(
+ arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
+ )
diff --git a/lam/models/encoders/dinov2/hub/utils.py b/lam/models/encoders/dinov2/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64
--- /dev/null
+++ b/lam/models/encoders/dinov2/hub/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+ output = F.pad(x, pads)
+ return output
diff --git a/lam/models/encoders/dinov2/layers/__init__.py b/lam/models/encoders/dinov2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..77967aa6ccfae24c39b8e167c83dd77073fd68fb
--- /dev/null
+++ b/lam/models/encoders/dinov2/layers/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# ******************************************************************************
+# Code modified by Zexin He in 2023-2024.
+# Modifications are marked with clearly visible comments
+# licensed under the Apache License, Version 2.0.
+# ******************************************************************************
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+# ********** Modified by Zexin He in 2023-2024 **********
+# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock
+from .block import Block, BlockWithModulation
+# ********************************************************
+from .attention import MemEffAttention
diff --git a/lam/models/encoders/dinov2/layers/attention.py b/lam/models/encoders/dinov2/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb76ef2816164729a58cceb18d0f000cfb18777
--- /dev/null
+++ b/lam/models/encoders/dinov2/layers/attention.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Attention)")
+ else:
+ warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/lam/models/encoders/dinov2/layers/block.py b/lam/models/encoders/dinov2/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5b50118c1579fd30cda0c2d60b95c85eb04204
--- /dev/null
+++ b/lam/models/encoders/dinov2/layers/block.py
@@ -0,0 +1,296 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+# ******************************************************************************
+# Code modified by Zexin He in 2023-2024.
+# Modifications are marked with clearly visible comments
+# licensed under the Apache License, Version 2.0.
+# ******************************************************************************
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Block)")
+ else:
+ warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+
+ warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+# ********** Modified by Zexin He in 2023-2024 **********
+# Override forward with modulation input
+class BlockWithModulation(Block):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x: Tensor, mod: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor, mod: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x, mod)))
+
+ def ffn_residual_func(x: Tensor, mod: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x, mod)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ raise NotImplementedError("Modulation with drop path ratio larger than 0.1 is not supported yet")
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, mod))
+ x = x + self.drop_path1(ffn_residual_func(x, mod)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, mod)
+ x = x + ffn_residual_func(x, mod)
+ return x
+# ********************************************************
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+
+ # ********** Modified by Zexin He in 2023-2024 **********
+ warnings.warn("NestedTensorBlock is deprecated for now!", DeprecationWarning)
+ # ********************************************************
+
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/lam/models/encoders/dinov2/layers/dino_head.py b/lam/models/encoders/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565
--- /dev/null
+++ b/lam/models/encoders/dinov2/layers/dino_head.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/lam/models/encoders/dinov2/layers/drop_path.py b/lam/models/encoders/dinov2/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/lam/models/encoders/dinov2/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/lam/models/encoders/dinov2/layers/layer_scale.py b/lam/models/encoders/dinov2/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/lam/models/encoders/dinov2/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/lam/models/encoders/dinov2/layers/mlp.py b/lam/models/encoders/dinov2/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/lam/models/encoders/dinov2/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/lam/models/encoders/dinov2/layers/patch_embed.py b/lam/models/encoders/dinov2/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/lam/models/encoders/dinov2/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/lam/models/encoders/dinov2/layers/swiglu_ffn.py b/lam/models/encoders/dinov2/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e9dafa4592a408f6874d54853e8f60db5c41f74
--- /dev/null
+++ b/lam/models/encoders/dinov2/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/lam/models/encoders/dinov2/models/__init__.py b/lam/models/encoders/dinov2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265
--- /dev/null
+++ b/lam/models/encoders/dinov2/models/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from . import vision_transformer as vits
+
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
diff --git a/lam/models/encoders/dinov2/models/vision_transformer.py b/lam/models/encoders/dinov2/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c90ac2be1fe294a0db6080cd24155629083d3ec9
--- /dev/null
+++ b/lam/models/encoders/dinov2/models/vision_transformer.py
@@ -0,0 +1,443 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+# ******************************************************************************
+# Code modified by Zexin He in 2023-2024.
+# Modifications are marked with clearly visible comments
+# licensed under the Apache License, Version 2.0.
+# ******************************************************************************
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+# ********** Modified by Zexin He in 2023-2024 **********
+# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock
+from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, Block, BlockWithModulation
+# ********************************************************
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ # ********** Modified by Zexin He in 2023-2024 **********
+ modulation_dim: int = None,
+ # ********************************************************
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+
+ # ********** Modified by Zexin He in 2023-2024 **********
+ block_norm_layer = None
+ if modulation_dim is not None:
+ from ....modulate import ModLN
+ block_norm_layer = partial(ModLN, mod_dim=modulation_dim)
+ else:
+ block_norm_layer = nn.LayerNorm
+ block_norm_layer = partial(block_norm_layer, eps=1e-6)
+ # ********************************************************
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ # ********** Modified by Zexin He in 2023-2024 **********
+ norm_layer=block_norm_layer,
+ # ********************************************************
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ # ********** Modified by Zexin He in 2023-2024 **********
+ # hacking unused mask_token for better DDP
+ # self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+ # ********************************************************
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ # ********** Modified by Zexin He in 2023-2024 **********
+ raise NotImplementedError("Masking is not supported in hacked DINOv2")
+ # x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+ # ********************************************************
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ # ********** Modified by Zexin He in 2023-2024 **********
+ def forward_features(self, x, masks=None, mod=None):
+ if isinstance(x, list):
+ raise DeprecationWarning("forward_features_list is deprecated, use forward_features")
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ if mod is None:
+ for blk in self.blocks:
+ x = blk(x)
+ else:
+ for blk in self.blocks:
+ x = blk(x, mod)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ # ********************************************************
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+# ********** Modified by Zexin He in 2023-2024 **********
+# block class selected from Block and BlockWithModulation
+
+def _block_cls(**kwargs):
+ modulation_dim = kwargs.get("modulation_dim", None)
+ if modulation_dim is None:
+ block_cls = Block
+ else:
+ block_cls = BlockWithModulation
+ return block_cls
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+# ********************************************************
diff --git a/lam/models/encoders/dinov2_dpt.py b/lam/models/encoders/dinov2_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..194d25e52fe57f1c6277b2d40e0f4bed9032ef5b
--- /dev/null
+++ b/lam/models/encoders/dinov2_dpt.py
@@ -0,0 +1,252 @@
+import cv2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+from torchvision.transforms import Compose
+
+# from lam.models.encoders.dpt_util.dinov2 import DINOv2
+from lam.models.encoders.dpt_util.blocks import FeatureFusionBlock, _make_scratch
+from lam.models.encoders.dpt_util.transform import Resize, NormalizeImage, PrepareForNet
+
+
+def _make_fusion_block(features, use_bn, size=None, use_conv1=True):
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ use_conv1=use_conv1,
+ )
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_feature, out_feature):
+ super().__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(out_feature),
+ nn.ReLU(True)
+ )
+
+ def forward(self, x):
+ return self.conv_block(x)
+
+
+class DPTHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ features=256,
+ use_bn=False,
+ out_channels=[256, 512, 1024, 1024],
+ use_clstoken=False,
+ out_channel=384,
+ ):
+ super(DPTHead, self).__init__()
+
+ self.use_clstoken = use_clstoken
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ) for out_channel in out_channels
+ ])
+
+ # self.resize_layers = nn.ModuleList([
+ # nn.ConvTranspose2d(
+ # in_channels=out_channels[0],
+ # out_channels=out_channels[0],
+ # kernel_size=4,
+ # stride=4,
+ # padding=0),
+ # nn.ConvTranspose2d(
+ # in_channels=out_channels[1],
+ # out_channels=out_channels[1],
+ # kernel_size=2,
+ # stride=2,
+ # padding=0),
+ # nn.Identity(),
+ # nn.Conv2d(
+ # in_channels=out_channels[3],
+ # out_channels=out_channels[3],
+ # kernel_size=3,
+ # stride=2,
+ # padding=1)
+ # ])
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(
+ nn.Linear(2 * in_channels, in_channels),
+ nn.GELU()))
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ groups=1,
+ expand=False,
+ )
+
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn, use_conv1=False)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ # self.scratch.output_conv1 = nn.Conv2d(head_features_1, out_channnels, kernel_size=3, stride=1, padding=1)
+
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, out_channel, kernel_size=1, stride=1, padding=0)
+
+ # self.scratch.output_conv2 = nn.Sequential(
+ # nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ # nn.ReLU(True),
+ # nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ # nn.ReLU(True),
+ # nn.Identity(),
+ # )
+
+ def forward(self, out_features, patch_h, patch_w):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[i](x)
+ # x = self.resize_layers[i](x)
+
+ out.append(x)
+
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:], scale_factor=1)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:], scale_factor=1)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:], scale_factor=1)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn, scale_factor=1)
+
+ # path_4 = self.scratch.refinenet4(layer_1_rn, size=layer_2_rn.shape[2:], scale_factor=1)
+ # path_3 = self.scratch.refinenet3(path_4, layer_2_rn, size=layer_3_rn.shape[2:], scale_factor=1)
+ # path_2 = self.scratch.refinenet2(path_3, layer_3_rn, size=layer_4_rn.shape[2:], scale_factor=1)
+ # path_1 = self.scratch.refinenet1(path_2, layer_4_rn, scale_factor=1)
+
+ out = self.scratch.output_conv1(path_1)
+ # out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+ # out = self.scratch.output_conv2(out)
+
+ return out
+
+
+class DINODPT(nn.Module):
+ def __init__(
+ self,
+ model_name="vitb",
+ out_dim=384,
+ use_bn=False,
+ use_clstoken=False
+ ):
+ super(DINODPT, self).__init__()
+
+ model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+ }
+
+ encoder = model_configs[model_name]["encoder"]
+ features = model_configs[model_name]["features"]
+ out_channels = model_configs[model_name]["out_channels"]
+
+
+ self.intermediate_layer_idx = {
+ 'vits': [2, 5, 8, 11],
+ 'vitb': [2, 5, 8, 11],
+ 'vitl': [4, 11, 17, 23],
+ 'vitg': [9, 19, 29, 39]
+ }
+
+ self.encoder = encoder
+
+ # self.dino_model = DINOv2(model_name=encoder)
+ self.dino_model = torch.hub.load('facebookresearch/dinov2', f'dinov2_{encoder}14', pretrained=True)
+ self.dense_head = DPTHead(self.dino_model.embed_dim, features, use_bn, out_channels=out_channels,
+ use_clstoken=use_clstoken, out_channel=out_dim)
+
+ self.dino_normlize = torchvision.transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ def forward(self, x, is_training=True):
+ x = self.dino_normlize(x)
+
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
+
+ features = self.dino_model.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
+
+ feat = self.dense_head(features, patch_h, patch_w)
+ # print(x.shape, feat.shape)
+ # depth = F.relu(depth)
+ # return depth.squeeze(1)
+ out_global = None
+ return feat, out_global
+
+ @torch.no_grad()
+ def infer_image(self, raw_image, input_size=518):
+ image, (h, w) = self.image2tensor(raw_image, input_size)
+
+ depth = self.forward(image)
+
+ depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
+
+ return depth.cpu().numpy()
+
+ def image2tensor(self, raw_image, input_size=518):
+ transform = Compose([
+ Resize(
+ width=input_size,
+ height=input_size,
+ resize_target=False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ])
+
+ h, w = raw_image.shape[:2]
+
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
+
+ image = transform({'image': image})['image']
+ image = torch.from_numpy(image).unsqueeze(0)
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+ image = image.to(DEVICE)
+
+ return image, (h, w)
diff --git a/lam/models/encoders/dinov2_dpt_wrapper.py b/lam/models/encoders/dinov2_dpt_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0629bee031a2e3a3abfa9a1778d6fd8b241303c
--- /dev/null
+++ b/lam/models/encoders/dinov2_dpt_wrapper.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+from lam.models.encoders.dinov2_dpt import DINODPT
+
+logger = get_logger(__name__)
+
+
+class Dinov2DPTWrapper(nn.Module):
+ """
+ Dinov2DPTWrapper using original implementation, hacked with modulation.
+ """
+ def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True, encoder_feat_dim: int = 384):
+ super().__init__()
+ self.modulation_dim = modulation_dim
+ # self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim)
+ # self.model = DINOBase(output_dim=384)
+ self.model = DINODPT(model_name="vitb", out_dim=encoder_feat_dim)
+
+ if freeze:
+ if modulation_dim is not None:
+ raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.")
+ self._freeze()
+ else:
+ for name, param in self.model.dino_model.named_parameters():
+ if name == "mask_token":
+ param.requires_grad = False
+
+ def _freeze(self):
+ logger.warning(f"======== Freezing Dinov2DPTWrapper ========")
+ self.model.dino_model.eval()
+ for name, param in self.model.dino_model.named_parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True):
+ from importlib import import_module
+ dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
+ model_fn = getattr(dinov2_hub, model_name)
+ logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
+ model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
+ return model
+
+ @torch.compile
+ def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
+ # image: [N, C, H, W]
+ # mod: [N, D] or None
+ # RGB image with [0,1] scale and properly sized
+ if self.modulation_dim is None:
+ assert mod is None, "Unexpected modulation input in dinov2 forward."
+ outs = self.model(image, is_training=True)
+ else:
+ assert mod is not None, "Modulation input is required in modulated dinov2 forward."
+ outs = self.model(image, mod=mod, is_training=True)
+
+ out_local, out_global = outs
+ if out_global is not None:
+ ret = torch.cat([out_local.permute(0, 2, 3, 1).flatten(1, 2), out_global.unsqueeze(1)], dim=1)
+ else:
+ ret = out_local.permute(0, 2, 3, 1).flatten(1, 2)
+ return ret
diff --git a/lam/models/encoders/dinov2_featup_wrapper.py b/lam/models/encoders/dinov2_featup_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c821cb4d4722735effff2f5e353fa52790d19c0
--- /dev/null
+++ b/lam/models/encoders/dinov2_featup_wrapper.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class Dinov2FeatUpWrapper(nn.Module):
+ """
+ Dinov2FeatUpWrapper using original implementation, hacked with modulation.
+ """
+ def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True, encoder_feat_dim: int = 384):
+ super().__init__()
+ self.modulation_dim = modulation_dim
+ self.model = torch.hub.load("mhamilton723/FeatUp", 'dinov2', use_norm=True)
+
+ if freeze:
+ if modulation_dim is not None:
+ raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.")
+ self._freeze()
+ else:
+ for name, param in self.model.named_parameters():
+ if name == "model.0.model.mask_token":
+ param.requires_grad = False
+
+ def _freeze(self):
+ logger.warning(f"======== Freezing Dinov2UnetWrapper ========")
+ self.model.model.eval()
+ for name, param in self.model.model.named_parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True):
+ from importlib import import_module
+ dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
+ model_fn = getattr(dinov2_hub, model_name)
+ logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
+ model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
+ return model
+
+ @torch.compile
+ def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
+ # image: [N, C, H, W]
+ # mod: [N, D] or None
+ # RGB image with [0,1] scale and properly sized
+ if self.modulation_dim is None:
+ assert mod is None, "Unexpected modulation input in dinov2 forward."
+ outs = self.model(image)
+ else:
+ assert mod is not None, "Modulation input is required in modulated dinov2 forward."
+ outs = self.model(image, mod=mod)
+ out_local = outs
+ out_local = nn.functional.avg_pool2d(out_local, stride=2, kernel_size=2)
+ ret = out_local.permute(0, 2, 3, 1).flatten(1, 2)
+ return ret
diff --git a/lam/models/encoders/dinov2_fusion_wrapper.py b/lam/models/encoders/dinov2_fusion_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac1cebde8de421237114f866c7fb63defbc89ca8
--- /dev/null
+++ b/lam/models/encoders/dinov2_fusion_wrapper.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class DPTHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ inner_channels,
+ use_clstoken=False,
+ out_channel=1024,
+ ):
+ super(DPTHead, self).__init__()
+
+ self.use_clstoken = use_clstoken
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ) for out_channel in inner_channels
+ ])
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(
+ nn.Linear(2 * in_channels, in_channels),
+ nn.GELU()))
+
+ self.output_conv = nn.Conv2d(sum(inner_channels) , out_channel, kernel_size=1, stride=1, padding=0)
+
+
+ def forward(self, out_features, patch_h, patch_w):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[i](x)
+
+ out.append(x)
+
+ fusion_feats = torch.cat(out, dim=1)
+
+ fusion_feats = self.output_conv(fusion_feats)
+
+ return fusion_feats
+
+
+class Dinov2FusionWrapper(nn.Module):
+ """
+ Dinov2FusionWrapper using original implementation, hacked with modulation.
+ """
+ def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True, encoder_feat_dim: int = 384):
+ super().__init__()
+ self.modulation_dim = modulation_dim
+ self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim)
+
+ self.intermediate_layer_idx_info = {
+ 'dinov2_vits14_reg': [2, 5, 8, 11],
+ 'dinov2_vitb14_reg': [2, 5, 8, 11],
+ 'dinov2_vitl14_reg': [4, 11, 17, 23],
+ 'dinov2_vitg14_reg': [9, 19, 29, 39]
+ }
+
+ self.intermediate_layer_idx = self.intermediate_layer_idx_info[model_name]
+ self.fusion_head = DPTHead(in_channels=self.model.embed_dim,
+ inner_channels=[self.model.embed_dim] * 4,
+ out_channel=encoder_feat_dim)
+
+ if freeze:
+ if modulation_dim is not None:
+ raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.")
+ self._freeze()
+
+
+ def _freeze(self):
+ # logger.warning(f"======== Freezing Dinov2FusionWrapper ========")
+ self.model.eval()
+ for name, param in self.model.named_parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True):
+ from importlib import import_module
+ dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
+ model_fn = getattr(dinov2_hub, model_name)
+ # logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
+ model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
+ return model
+
+ @torch.compile
+ def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
+ # image: [N, C, H, W]
+ # mod: [N, D] or None
+ # RGB image with [0,1] scale and properly sized
+
+ patch_h, patch_w = image.shape[-2] // self.model.patch_size, image.shape[-1] // self.model.patch_size
+
+ features = self.model.get_intermediate_layers(image, self.intermediate_layer_idx, return_class_token=True)
+
+ out_local = self.fusion_head(features, patch_h, patch_w)
+
+ out_global = None
+ if out_global is not None:
+ ret = torch.cat([out_local.permute(0, 2, 3, 1).flatten(1, 2), out_global.unsqueeze(1)], dim=1)
+ else:
+ ret = out_local.permute(0, 2, 3, 1).flatten(1, 2)
+ return ret
diff --git a/lam/models/encoders/dinov2_unet.py b/lam/models/encoders/dinov2_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..07e0e6b63f79da1dbee21fb0eec99c68498f342a
--- /dev/null
+++ b/lam/models/encoders/dinov2_unet.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python
+# Copyright (c) Xuangeng Chu (xg.chu@outlook.com)
+
+import torch
+import torchvision
+import torch.nn as nn
+import timm
+from accelerate.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+
+class DINOBase(nn.Module):
+ def __init__(self, output_dim=128, only_global=False):
+ super().__init__()
+ self.only_global = only_global
+ assert self.only_global == False
+ self.dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', pretrained=True)
+
+ # self.encoder = timm.create_model("resnet18", pretrained=True)
+ # del self.encoder.global_pool
+ # del self.encoder.fc
+
+ # model_name = "dinov2_vits14_reg"
+ # modulation_dim = None
+ # self.dino_model = self._build_dinov2(model_name, modulation_dim=modulation_dim)
+
+ self.dino_normlize = torchvision.transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ in_dim = self.dino_model.blocks[0].attn.qkv.in_features
+ hidden_dims=256
+ out_dims=[256, 512, 1024, 1024]
+ # modules
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_dim, out_dim, kernel_size=1, stride=1, padding=0,
+ ) for out_dim in out_dims
+ ])
+
+ self.resize_layers = nn.ModuleList([
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(
+ out_dims[0], out_dims[0], kernel_size=3, stride=1, padding=1),
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(
+ out_dims[0], out_dims[0], kernel_size=3, stride=1, padding=1)
+ ),
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(
+ out_dims[1], out_dims[1], kernel_size=3, stride=1, padding=1)
+ ),
+ nn.Sequential(
+ nn.Conv2d(
+ out_dims[2], out_dims[2], kernel_size=3, stride=1, padding=1)
+ ),
+ nn.Sequential(
+ nn.Conv2d(
+ out_dims[3], out_dims[3], kernel_size=3, stride=2, padding=1)
+ )
+ ])
+ # self.layer_rn = nn.ModuleList([
+ # nn.Conv2d(out_dims[0]+64, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ # nn.Conv2d(out_dims[1]+128, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ # nn.Conv2d(out_dims[2]+256, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ # nn.Conv2d(out_dims[3]+512, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ # ])
+ self.layer_rn = nn.ModuleList([
+ nn.Conv2d(out_dims[0]+3, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.Conv2d(out_dims[1]+3, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.Conv2d(out_dims[2]+3, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.Conv2d(out_dims[3]+3, hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ ])
+ # self.layer_rn = nn.ModuleList([
+ # nn.Conv2d(out_dims[0], hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ # nn.Conv2d(out_dims[1], hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ # nn.Conv2d(out_dims[2], hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ # nn.Conv2d(out_dims[3], hidden_dims, kernel_size=3, stride=1, padding=1, bias=False),
+ # ])
+
+ self.refinenet = nn.ModuleList([
+ FeatureFusionBlock(hidden_dims, nn.ReLU(False), use_conv1=False),
+ FeatureFusionBlock(hidden_dims, nn.ReLU(False)),
+ FeatureFusionBlock(hidden_dims, nn.ReLU(False)),
+ FeatureFusionBlock(hidden_dims, nn.ReLU(False)),
+ ])
+ self.output_conv = nn.Conv2d(hidden_dims, output_dim, kernel_size=3, stride=1, padding=1)
+ # self.output_gloabl_proj = nn.Linear(384, output_dim)
+
+ @staticmethod
+ def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True):
+ from importlib import import_module
+ dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
+ model_fn = getattr(dinov2_hub, model_name)
+ logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
+ model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
+ return model
+
+ def forward(self, images, output_size=None, is_training=True):
+ # enc_output = self.encoder.forward_intermediates(images, stop_early=True, intermediates_only=True)
+ # enc_out4 = enc_output[4] # 32
+ # enc_out3 = enc_output[3] # 16
+ # enc_out2 = enc_output[2] # 8
+ # enc_out1 = enc_output[1] # 4
+
+ images = self.dino_normlize(images)
+ patch_h, patch_w = images.shape[-2]//14, images.shape[-1]//14
+
+ image_features = self.dino_model.get_intermediate_layers(images, 4)
+
+ out_features = []
+ for i, feature in enumerate(image_features):
+ feature = feature.permute(0, 2, 1).reshape(
+ (feature.shape[0], feature.shape[-1], patch_h, patch_w)
+ )
+ feature = self.projects[i](feature)
+ feature = self.resize_layers[i](feature)
+ # print(enc_output[i+1].shape, feature.shape)
+ feature = torch.cat([
+ nn.functional.interpolate(images, (feature.shape[-2], feature.shape[-1]), mode="bilinear", align_corners=True),
+ feature
+ ], dim=1
+ )
+ out_features.append(feature)
+ layer_rns = []
+ for i, feature in enumerate(out_features):
+ layer_rns.append(self.layer_rn[i](feature))
+
+ path_4 = self.refinenet[0](layer_rns[3], size=layer_rns[2].shape[2:])
+ path_3 = self.refinenet[1](path_4, layer_rns[2], size=layer_rns[1].shape[2:])
+ path_2 = self.refinenet[2](path_3, layer_rns[1], size=layer_rns[0].shape[2:])
+ path_1 = self.refinenet[3](path_2, layer_rns[0])
+ out = self.output_conv(path_1)
+
+ if output_size is not None:
+ out = nn.functional.interpolate(out, output_size, mode="bilinear", align_corners=True)
+ # out_global = image_features[-1][:, 0]
+ # out_global = self.output_gloabl_proj(out_global)
+ out_global = None
+ return out, out_global
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+ # return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None,
+ use_conv1=True):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ if use_conv1:
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+ self.size=size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output = output + res
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+ output = nn.functional.interpolate(
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
+ )
+ output = self.out_conv(output)
+ return output
diff --git a/lam/models/encoders/dinov2_unet_wrapper.py b/lam/models/encoders/dinov2_unet_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..382823b874ba33745235e8ceae30e3d4240a8e34
--- /dev/null
+++ b/lam/models/encoders/dinov2_unet_wrapper.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+from lam.models.encoders.dinov2_unet import DINOBase
+
+logger = get_logger(__name__)
+
+
+class Dinov2UnetWrapper(nn.Module):
+ """
+ Dino v2 wrapper using original implementation, hacked with modulation.
+ """
+ def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True, encoder_feat_dim: int = 384):
+ super().__init__()
+ self.modulation_dim = modulation_dim
+ # self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim)
+ self.model = DINOBase(output_dim=encoder_feat_dim)
+ assert model_name in ["no_avg", "avg_2"]
+ self.model_name = model_name
+
+ if freeze:
+ if modulation_dim is not None:
+ raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.")
+ self._freeze()
+ else:
+ for name, param in self.model.dino_model.named_parameters():
+ if name == "mask_token":
+ param.requires_grad = False
+
+ def _freeze(self):
+ logger.warning(f"======== Freezing Dinov2UnetWrapper ========")
+ self.model.dino_model.eval()
+ for name, param in self.model.dino_model.named_parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True):
+ from importlib import import_module
+ dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
+ model_fn = getattr(dinov2_hub, model_name)
+ logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
+ model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
+ return model
+
+ @torch.compile
+ def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
+ # image: [N, C, H, W]
+ # mod: [N, D] or None
+ # RGB image with [0,1] scale and properly sized
+ if self.modulation_dim is None:
+ assert mod is None, "Unexpected modulation input in dinov2 forward."
+ outs = self.model(image, is_training=True)
+ else:
+ assert mod is not None, "Modulation input is required in modulated dinov2 forward."
+ outs = self.model(image, mod=mod, is_training=True)
+
+ out_local, out_global = outs
+
+ if self.model_name == "avg_2":
+ out_local = nn.functional.avg_pool2d(out_local, stride=2, kernel_size=2)
+
+ if out_global is not None:
+ ret = torch.cat([out_local.permute(0, 2, 3, 1).flatten(1, 2), out_global.unsqueeze(1)], dim=1)
+ else:
+ ret = out_local.permute(0, 2, 3, 1).flatten(1, 2)
+ return ret
diff --git a/lam/models/encoders/dinov2_wrapper.py b/lam/models/encoders/dinov2_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8453ca68833e3cea5dd40055cfe869ee0ddf317a
--- /dev/null
+++ b/lam/models/encoders/dinov2_wrapper.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class Dinov2Wrapper(nn.Module):
+ """
+ Dino v2 wrapper using original implementation, hacked with modulation.
+ """
+ def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True, encoder_feat_dim: int = 384):
+ super().__init__()
+ self.modulation_dim = modulation_dim
+ self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim)
+ if freeze:
+ if modulation_dim is not None:
+ raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.")
+ self._freeze()
+
+ def _freeze(self):
+ logger.warning(f"======== Freezing Dinov2Wrapper ========")
+ self.model.eval()
+ for name, param in self.model.named_parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True):
+ from importlib import import_module
+ dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
+ model_fn = getattr(dinov2_hub, model_name)
+ logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
+ model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
+ return model
+
+ @torch.compile
+ def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
+ # image: [N, C, H, W]
+ # mod: [N, D] or None
+ # RGB image with [0,1] scale and properly sized
+ if self.modulation_dim is None:
+ assert mod is None, "Unexpected modulation input in dinov2 forward."
+ outs = self.model(image, is_training=True)
+ else:
+ assert mod is not None, "Modulation input is required in modulated dinov2 forward."
+ outs = self.model(image, mod=mod, is_training=True)
+ ret = torch.cat([
+ outs["x_norm_clstoken"].unsqueeze(dim=1),
+ outs["x_norm_patchtokens"],
+ ], dim=1)
+ return ret
diff --git a/lam/models/encoders/dpt_util/__init__.py b/lam/models/encoders/dpt_util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lam/models/encoders/dpt_util/blocks.py b/lam/models/encoders/dpt_util/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..e562a4b387171d5afdcfab5a524b85b32981d891
--- /dev/null
+++ b/lam/models/encoders/dpt_util/blocks.py
@@ -0,0 +1,151 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ if self.bn == True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn == True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn == True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ use_conv1=True
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ if use_conv1:
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+
+ self.size=size
+
+ def forward(self, *xs, size=None, scale_factor=2):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": scale_factor}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+
+ output = self.out_conv(output)
+
+ return output
diff --git a/lam/models/encoders/dpt_util/transform.py b/lam/models/encoders/dpt_util/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14aacd44ea086b01725a9ca68bb49eadcf37d73
--- /dev/null
+++ b/lam/models/encoders/dpt_util/transform.py
@@ -0,0 +1,158 @@
+import numpy as np
+import cv2
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
+
+ # resize sample
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
+
+ if self.__resize_target:
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
+
+ if "mask" in sample:
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ return sample
\ No newline at end of file
diff --git a/lam/models/encoders/xunet_wrapper.py b/lam/models/encoders/xunet_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0759a2aa6e029d7e691214e5a0a7435e8cfc996
--- /dev/null
+++ b/lam/models/encoders/xunet_wrapper.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+import timm
+from accelerate.logging import get_logger
+
+logger = get_logger(__name__)
+
+class XUNet(nn.Module):
+ def __init__(self, model_name="swin_base_patch4_window12_384_in22k", encoder_feat_dim=384):
+ super(XUNet, self).__init__()
+ # Swin Transformer Encoder
+ self.encoder = timm.create_model(model_name, pretrained=True)
+ # swin
+ # del self.encoder.head
+ # del self.encoder.norm
+ # resnet
+ del self.encoder.global_pool
+ del self.encoder.fc
+
+ # Decoder layers
+ # self.upconv4 = self.upconv_block(2048, 1024) # Upsample
+ # self.upconv3 = self.upconv_block(1024, 512)
+ # self.upconv2 = self.upconv_block(512, 256)
+ # self.upconv1 = self.upconv_block(256, 64)
+
+ self.upconv4 = self.upconv_block(512, 256) # Upsample
+ self.upconv3 = self.upconv_block(256, 128)
+ self.upconv2 = self.upconv_block(128, 64)
+ # self.upconv1 = self.upconv_block(64, 64)
+
+ self.out_conv = nn.Conv2d(64, encoder_feat_dim, kernel_size=1)
+
+
+ def upconv_block(self, in_channels, out_channels):
+ return nn.Sequential(
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x):
+ # Encoder part using Swin Transformer
+ enc_output = self.encoder.forward_intermediates(x, stop_early=True, intermediates_only=True)
+
+ # for e in enc_output:
+ # print(e.shape, x.shape)
+
+ # Assuming output of the encoder is a list of feature maps
+ # Resize them according to UNet architecture
+ enc_out4 = enc_output[4] # Adjust according to the feature layers of Swin
+ enc_out3 = enc_output[3]
+ enc_out2 = enc_output[2]
+ enc_out1 = enc_output[1]
+ # enc_out0 = enc_output[0]
+
+ # Decoder part
+ x = self.upconv4(enc_out4)
+ x = x + enc_out3 # s16, Skip connection
+ x = self.upconv3(x)
+ x = x + enc_out2 # s8
+ x = self.upconv2(x)
+ x = x + enc_out1 # s4
+ # x = self.upconv1(x)
+ # x = x + enc_out0 # s2
+
+ x = self.out_conv(x)
+ return x
+
+
+class XnetWrapper(nn.Module):
+ """
+ XnetWrapper using original implementation, hacked with modulation.
+ """
+ def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True, encoder_feat_dim: int = 384):
+ super().__init__()
+ self.modulation_dim = modulation_dim
+ self.model = XUNet(model_name=model_name, encoder_feat_dim=encoder_feat_dim)
+
+ if freeze:
+ if modulation_dim is not None:
+ raise ValueError("Modulated SwinUnetWrapper requires training, freezing is not allowed.")
+ self._freeze()
+
+ def _freeze(self):
+ logger.warning(f"======== Freezing SwinUnetWrapper ========")
+ self.model.eval()
+ for name, param in self.model.named_parameters():
+ param.requires_grad = False
+
+ @torch.compile
+ def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
+ # image: [N, C, H, W]
+ # mod: [N, D] or None
+ # RGB image with [0,1] scale and properly sized
+ outs = self.model(image)
+ ret = outs.permute(0, 2, 3, 1).flatten(1, 2)
+ return ret
diff --git a/lam/models/modeling_lam.py b/lam/models/modeling_lam.py
new file mode 100644
index 0000000000000000000000000000000000000000..912bda7b4ae44f4dec680ddee73d7b3acbc55e6b
--- /dev/null
+++ b/lam/models/modeling_lam.py
@@ -0,0 +1,367 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import math
+from collections import defaultdict
+import numpy as np
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+from einops import rearrange, repeat
+
+from .transformer import TransformerDecoder
+from lam.models.rendering.gs_renderer import GS3DRenderer, PointEmbed
+from diffusers.utils import is_torch_version
+
+logger = get_logger(__name__)
+
+
+class ModelLAM(nn.Module):
+ """
+ Full model of the basic single-view large reconstruction model.
+ """
+ def __init__(self,
+ transformer_dim: int, transformer_layers: int, transformer_heads: int,
+ transformer_type="cond",
+ tf_grad_ckpt=False,
+ encoder_grad_ckpt=False,
+ encoder_freeze: bool = True, encoder_type: str = 'dino',
+ encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768,
+ num_pcl: int=2048, pcl_dim: int=512,
+ human_model_path=None,
+ flame_subdivide_num=2,
+ flame_type="flame",
+ gs_query_dim=None,
+ gs_use_rgb=False,
+ gs_sh=3,
+ gs_mlp_network_config=None,
+ gs_xyz_offset_max_step=1.8 / 32,
+ gs_clip_scaling=0.2,
+ shape_param_dim=100,
+ expr_param_dim=50,
+ fix_opacity=False,
+ fix_rotation=False,
+ flame_scale=1.0,
+ **kwargs,
+ ):
+ super().__init__()
+ self.gradient_checkpointing = tf_grad_ckpt
+ self.encoder_gradient_checkpointing = encoder_grad_ckpt
+
+ # attributes
+ self.encoder_feat_dim = encoder_feat_dim
+ self.conf_use_pred_img = False
+ self.conf_cat_feat = False and self.conf_use_pred_img # True # False
+
+ # modules
+ # image encoder
+ self.encoder = self._encoder_fn(encoder_type)(
+ model_name=encoder_model_name,
+ freeze=encoder_freeze,
+ encoder_feat_dim=encoder_feat_dim,
+ )
+
+ # learnable points embedding
+ skip_decoder = False
+ self.latent_query_points_type = kwargs.get("latent_query_points_type", "e2e_flame")
+ if self.latent_query_points_type == "embedding":
+ self.num_pcl = num_pcl
+ self.pcl_embeddings = nn.Embedding(num_pcl , pcl_dim)
+ elif self.latent_query_points_type.startswith("flame"):
+ latent_query_points_file = os.path.join(human_model_path, "flame_points", f"{self.latent_query_points_type}.npy")
+ pcl_embeddings = torch.from_numpy(np.load(latent_query_points_file)).float()
+ print(f"==========load flame points:{latent_query_points_file}, shape:{pcl_embeddings.shape}")
+ self.register_buffer("pcl_embeddings", pcl_embeddings)
+ self.pcl_embed = PointEmbed(dim=pcl_dim)
+ elif self.latent_query_points_type.startswith("e2e_flame"):
+ skip_decoder = True
+ self.pcl_embed = PointEmbed(dim=pcl_dim)
+ else:
+ raise NotImplementedError
+ print("==="*16*3, f"\nskip_decoder: {skip_decoder}", "\n"+"==="*16*3)
+ # transformer
+ self.transformer = TransformerDecoder(
+ block_type=transformer_type,
+ num_layers=transformer_layers, num_heads=transformer_heads,
+ inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=None,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
+
+ # renderer
+ self.renderer = GS3DRenderer(human_model_path=human_model_path,
+ subdivide_num=flame_subdivide_num,
+ smpl_type=flame_type,
+ feat_dim=transformer_dim,
+ query_dim=gs_query_dim,
+ use_rgb=gs_use_rgb,
+ sh_degree=gs_sh,
+ mlp_network_config=gs_mlp_network_config,
+ xyz_offset_max_step=gs_xyz_offset_max_step,
+ clip_scaling=gs_clip_scaling,
+ scale_sphere=kwargs.get("scale_sphere", False),
+ shape_param_dim=shape_param_dim,
+ expr_param_dim=expr_param_dim,
+ fix_opacity=fix_opacity,
+ fix_rotation=fix_rotation,
+ skip_decoder=skip_decoder,
+ decode_with_extra_info=kwargs.get("decode_with_extra_info", None),
+ gradient_checkpointing=self.gradient_checkpointing,
+ add_teeth=kwargs.get("add_teeth", True),
+ teeth_bs_flag=kwargs.get("teeth_bs_flag", False),
+ oral_mesh_flag=kwargs.get("oral_mesh_flag", False),
+ use_mesh_shading=kwargs.get('use_mesh_shading', False),
+ render_rgb=kwargs.get("render_rgb", True),
+ )
+
+ def get_last_layer(self):
+ return self.renderer.gs_net.out_layers["shs"].weight
+
+ @staticmethod
+ def _encoder_fn(encoder_type: str):
+ encoder_type = encoder_type.lower()
+ assert encoder_type in ['dino', 'dinov2', 'dinov2_unet', 'resunet', 'dinov2_featup', 'dinov2_dpt', 'dinov2_fusion'], "Unsupported encoder type"
+ if encoder_type == 'dino':
+ from .encoders.dino_wrapper import DinoWrapper
+ # logger.info("Using DINO as the encoder")
+ return DinoWrapper
+ elif encoder_type == 'dinov2':
+ from .encoders.dinov2_wrapper import Dinov2Wrapper
+ # logger.info("Using DINOv2 as the encoder")
+ return Dinov2Wrapper
+ elif encoder_type == 'dinov2_unet':
+ from .encoders.dinov2_unet_wrapper import Dinov2UnetWrapper
+ # logger.info("Using Dinov2Unet as the encoder")
+ return Dinov2UnetWrapper
+ elif encoder_type == 'resunet':
+ from .encoders.xunet_wrapper import XnetWrapper
+ # logger.info("Using XnetWrapper as the encoder")
+ return XnetWrapper
+ elif encoder_type == 'dinov2_featup':
+ from .encoders.dinov2_featup_wrapper import Dinov2FeatUpWrapper
+ # logger.info("Using Dinov2FeatUpWrapper as the encoder")
+ return Dinov2FeatUpWrapper
+ elif encoder_type == 'dinov2_dpt':
+ from .encoders.dinov2_dpt_wrapper import Dinov2DPTWrapper
+ # logger.info("Using Dinov2DPTWrapper as the encoder")
+ return Dinov2DPTWrapper
+ elif encoder_type == 'dinov2_fusion':
+ from .encoders.dinov2_fusion_wrapper import Dinov2FusionWrapper
+ # logger.info("Using Dinov2FusionWrapper as the encoder")
+ return Dinov2FusionWrapper
+
+ def forward_transformer(self, image_feats, camera_embeddings, query_points, query_feats=None):
+ # assert image_feats.shape[0] == camera_embeddings.shape[0], \
+ # "Batch size mismatch for image_feats and camera_embeddings!"
+ B = image_feats.shape[0]
+ if self.latent_query_points_type == "embedding":
+ range_ = torch.arange(self.num_pcl, device=image_feats.device)
+ x = self.pcl_embeddings(range_).unsqueeze(0).repeat((B, 1, 1)) # [B, L, D]
+
+ elif self.latent_query_points_type.startswith("flame"):
+ x = self.pcl_embed(self.pcl_embeddings.unsqueeze(0)).repeat((B, 1, 1)) # [B, L, D]
+
+ elif self.latent_query_points_type.startswith("e2e_flame"):
+ x = self.pcl_embed(query_points) # [B, L, D]
+
+ x = x.to(image_feats.dtype)
+ if query_feats is not None:
+ x = x + query_feats.to(image_feats.dtype)
+ x = self.transformer(
+ x,
+ cond=image_feats,
+ mod=camera_embeddings,
+ ) # [B, L, D]
+ # x = x.to(image_feats.dtype)
+ return x
+
+ def forward_encode_image(self, image):
+ # encode image
+ if self.training and self.encoder_gradient_checkpointing:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ image_feats = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.encoder),
+ image,
+ **ckpt_kwargs,
+ )
+ else:
+ image_feats = self.encoder(image)
+ return image_feats
+
+ @torch.compile
+ def forward_latent_points(self, image, camera, query_points=None, additional_features=None):
+ # image: [B, C_img, H_img, W_img]
+ # camera: [B, D_cam_raw]
+ B = image.shape[0]
+
+ # encode image
+ image_feats = self.forward_encode_image(image)
+
+ assert image_feats.shape[-1] == self.encoder_feat_dim, \
+ f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"
+
+ if additional_features is not None and len(additional_features.keys()) > 0:
+ image_feats_bchw = rearrange(image_feats, "b (h w) c -> b c h w", h=int(math.sqrt(image_feats.shape[1])))
+ additional_features["source_image_feats"] = image_feats_bchw
+ proj_feats = self.renderer.get_batch_project_feats(None, query_points, additional_features=additional_features, feat_nms=['source_image_feats'], use_mesh=True)
+ query_feats = proj_feats['source_image_feats']
+ else:
+ query_feats = None
+ # # embed camera
+ # camera_embeddings = self.camera_embedder(camera)
+ # assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
+ # f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
+
+ # transformer generating latent points
+ tokens = self.forward_transformer(image_feats, camera_embeddings=None, query_points=query_points, query_feats=query_feats)
+
+ return tokens, image_feats
+
+ def forward(self, image, source_c2ws, source_intrs, render_c2ws, render_intrs, render_bg_colors, flame_params, source_flame_params=None, render_images=None, data=None):
+ # image: [B, N_ref, C_img, H_img, W_img]
+ # source_c2ws: [B, N_ref, 4, 4]
+ # source_intrs: [B, N_ref, 4, 4]
+ # render_c2ws: [B, N_source, 4, 4]
+ # render_intrs: [B, N_source, 4, 4]
+ # render_bg_colors: [B, N_source, 3]
+ # flame_params: Dict, e.g., pose_shape: [B, N_source, 21, 3], betas:[B, 100]
+ assert image.shape[0] == render_c2ws.shape[0], "Batch size mismatch for image and render_c2ws"
+ assert image.shape[0] == render_bg_colors.shape[0], "Batch size mismatch for image and render_bg_colors"
+ assert image.shape[0] == flame_params["betas"].shape[0], "Batch size mismatch for image and flame_params"
+ assert image.shape[0] == flame_params["expr"].shape[0], "Batch size mismatch for image and flame_params"
+ assert len(flame_params["betas"].shape) == 2
+ render_h, render_w = int(render_intrs[0, 0, 1, 2] * 2), int(render_intrs[0, 0, 0, 2] * 2)
+ query_points = None
+
+ if self.latent_query_points_type.startswith("e2e_flame"):
+ query_points, flame_params = self.renderer.get_query_points(flame_params,
+ device=image.device)
+
+ additional_features = {}
+
+ latent_points, image_feats = self.forward_latent_points(image[:, 0], camera=None, query_points=query_points, additional_features=additional_features) # [B, N, C]
+
+ additional_features.update({
+ "image_feats": image_feats, "image": image[:, 0],
+ })
+ image_feats_bchw = rearrange(image_feats, "b (h w) c -> b c h w", h=int(math.sqrt(image_feats.shape[1])))
+ additional_features["image_feats_bchw"] = image_feats_bchw
+
+ # render target views
+ render_results = self.renderer(gs_hidden_features=latent_points,
+ query_points=query_points,
+ flame_data=flame_params,
+ c2w=render_c2ws,
+ intrinsic=render_intrs,
+ height=render_h,
+ width=render_w,
+ background_color=render_bg_colors,
+ additional_features=additional_features
+ )
+
+ N, M = render_c2ws.shape[:2]
+ assert render_results['comp_rgb'].shape[0] in [N, N], "Batch size mismatch for render_results"
+ assert render_results['comp_rgb'].shape[1] in [M, M*2], "Number of rendered views should be consistent with render_cameras"
+
+ if self.use_conf_map:
+ b, v = render_images.shape[:2]
+ if self.conf_use_pred_img:
+ render_images = repeat(render_images, "b v c h w -> (b v r) c h w", r=2)
+ pred_images = rearrange(render_results['comp_rgb'].detach().clone(), "b v c h w -> (b v) c h w")
+ else:
+ render_images = rearrange(render_images, "b v c h w -> (b v) c h w")
+ pred_images = None
+ conf_sigma_l1, conf_sigma_percl = self.conf_net(render_images, pred_images) # Bx2xHxW
+ conf_sigma_l1 = rearrange(conf_sigma_l1, "(b v) c h w -> b v c h w", b=b, v=v)
+ conf_sigma_percl = rearrange(conf_sigma_percl, "(b v) c h w -> b v c h w", b=b, v=v)
+ conf_dict = {
+ "conf_sigma_l1": conf_sigma_l1,
+ "conf_sigma_percl": conf_sigma_percl,
+ }
+ else:
+ conf_dict = {}
+ # self.conf_sigma_l1 = conf_sigma_l1[:,:1]
+ # self.conf_sigma_l1_flip = conf_sigma_l1[:,1:]
+ # self.conf_sigma_percl = conf_sigma_percl[:,:1]
+ # self.conf_sigma_percl_flip = conf_sigma_percl[:,1:]
+
+ return {
+ 'latent_points': latent_points,
+ **render_results,
+ **conf_dict,
+ }
+
+ @torch.no_grad()
+ def infer_single_view(self, image, source_c2ws, source_intrs, render_c2ws,
+ render_intrs, render_bg_colors, flame_params):
+ # image: [B, N_ref, C_img, H_img, W_img]
+ # source_c2ws: [B, N_ref, 4, 4]
+ # source_intrs: [B, N_ref, 4, 4]
+ # render_c2ws: [B, N_source, 4, 4]
+ # render_intrs: [B, N_source, 4, 4]
+ # render_bg_colors: [B, N_source, 3]
+ # flame_params: Dict, e.g., pose_shape: [B, N_source, 21, 3], betas:[B, 100]
+ assert image.shape[0] == render_c2ws.shape[0], "Batch size mismatch for image and render_c2ws"
+ assert image.shape[0] == render_bg_colors.shape[0], "Batch size mismatch for image and render_bg_colors"
+ assert image.shape[0] == flame_params["betas"].shape[0], "Batch size mismatch for image and flame_params"
+ assert image.shape[0] == flame_params["expr"].shape[0], "Batch size mismatch for image and flame_params"
+ assert len(flame_params["betas"].shape) == 2
+ render_h, render_w = int(render_intrs[0, 0, 1, 2] * 2), int(render_intrs[0, 0, 0, 2] * 2)
+ assert image.shape[0] == 1
+ num_views = render_c2ws.shape[1]
+ query_points = None
+
+ if self.latent_query_points_type.startswith("e2e_flame"):
+ query_points, flame_params = self.renderer.get_query_points(flame_params,
+ device=image.device)
+ latent_points, image_feats = self.forward_latent_points(image[:, 0], camera=None, query_points=query_points) # [B, N, C]
+ image_feats_bchw = rearrange(image_feats, "b (h w) c -> b c h w", h=int(math.sqrt(image_feats.shape[1])))
+
+ gs_model_list, query_points, flame_params, _ = self.renderer.forward_gs(gs_hidden_features=latent_points,
+ query_points=query_points,
+ flame_data=flame_params,
+ additional_features={"image_feats": image_feats, "image": image[:, 0], "image_feats_bchw": image_feats_bchw})
+
+ render_res_list = []
+ for view_idx in range(num_views):
+ render_res = self.renderer.forward_animate_gs(gs_model_list,
+ query_points,
+ self.renderer.get_single_view_smpl_data(flame_params, view_idx),
+ render_c2ws[:, view_idx:view_idx+1],
+ render_intrs[:, view_idx:view_idx+1],
+ render_h,
+ render_w,
+ render_bg_colors[:, view_idx:view_idx+1])
+ render_res_list.append(render_res)
+
+ out = defaultdict(list)
+ for res in render_res_list:
+ for k, v in res.items():
+ out[k].append(v)
+ for k, v in out.items():
+ # print(f"out key:{k}")
+ if isinstance(v[0], torch.Tensor):
+ out[k] = torch.concat(v, dim=1)
+ if k in ["comp_rgb", "comp_mask", "comp_depth"]:
+ out[k] = out[k][0].permute(0, 2, 3, 1) # [1, Nv, 3, H, W] -> [Nv, 3, H, W] - > [Nv, H, W, 3]
+ else:
+ out[k] = v
+ out['cano_gs_lst'] = gs_model_list
+ return out
+
diff --git a/lam/models/modulate.py b/lam/models/modulate.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d2a0f0240cc1d596a9a544d56eac5ee7e03cc7d
--- /dev/null
+++ b/lam/models/modulate.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+
+
+class ModLN(nn.Module):
+ """
+ Modulation with adaLN.
+
+ References:
+ DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
+ """
+ def __init__(self, inner_dim: int, mod_dim: int, eps: float):
+ super().__init__()
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
+ self.mlp = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(mod_dim, inner_dim * 2),
+ )
+
+ @staticmethod
+ def modulate(x, shift, scale):
+ # x: [N, L, D]
+ # shift, scale: [N, D]
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+ def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
+ shift, scale = self.mlp(mod).chunk(2, dim=-1) # [N, D]
+ return self.modulate(self.norm(x), shift, scale) # [N, L, D]
diff --git a/lam/models/rendering/__init__.py b/lam/models/rendering/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/lam/models/rendering/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/lam/models/rendering/flame_model/flame.py b/lam/models/rendering/flame_model/flame.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bc2158f54dd65b7d795f2984dcf42d49a38c55d
--- /dev/null
+++ b/lam/models/rendering/flame_model/flame.py
@@ -0,0 +1,1545 @@
+# Code heavily inspired by https://github.com/HavenFeng/photometric_optimization/blob/master/models/FLAME.py.
+# Please consider citing their work if you find this code useful. The code is subject to the license available via
+# https://github.com/vchoutas/flame/edit/master/LICENSE
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+
+from lam.models.rendering.flame_model.lbs import lbs, vertices2landmarks, blend_shapes, vertices2joints
+from lam.models.rendering.flame_model.lbs import batch_rigid_transform, batch_rodrigues
+
+import os
+import json
+import torch
+import trimesh
+import torch.nn as nn
+import numpy as np
+import pickle
+from collections import defaultdict
+try:
+ from pytorch3d.io import load_obj
+except ImportError:
+ from utils.pytorch3d_load_obj import load_obj
+
+from pytorch3d.structures import Meshes
+from pytorch3d.ops import SubdivideMeshes
+from einops import rearrange, repeat
+from lam.models.rendering.utils.mesh_utils import compute_face_normals, compute_face_orientation
+from lam.models.rendering.utils.uv_utils import (
+ gen_tritex,
+ uniform_sampling_barycoords,
+ reweight_uvcoords_by_barycoords,
+ reweight_verts_by_barycoords
+)
+from pytorch3d.transforms import (
+ axis_angle_to_quaternion,
+ quaternion_to_axis_angle,
+ matrix_to_quaternion,
+ quaternion_multiply,
+)
+import functools
+from lam.models.rendering.gaussian_model import GaussianModel
+import torch.nn.functional as F
+
+
+def to_tensor(array, dtype=torch.float32):
+ if "torch.tensor" not in str(type(array)):
+ return torch.tensor(array, dtype=dtype)
+
+
+def to_np(array, dtype=np.float32):
+ if "scipy.sparse" in str(type(array)):
+ array = array.todense()
+ return np.array(array, dtype=dtype)
+
+
+class Struct(object):
+ def __init__(self, **kwargs):
+ for key, val in kwargs.items():
+ setattr(self, key, val)
+
+def face_vertices(vertices, faces):
+ """
+ :param vertices: [batch size, number of vertices, 3]
+ :param faces: [batch size, number of faces, 3]
+ :return: [batch size, number of faces, 3, 3]
+ """
+ assert vertices.ndimension() == 3
+ assert faces.ndimension() == 3
+ assert vertices.shape[0] == faces.shape[0]
+ assert vertices.shape[2] == 3
+ assert faces.shape[2] == 3
+
+ bs, nv = vertices.shape[:2]
+ bs, nf = faces.shape[:2]
+ device = vertices.device
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+ vertices = vertices.reshape((bs * nv, 3))
+ # pytorch only supports long and byte tensors for indexing
+ return vertices[faces.long()]
+
+
+class FlameHead(nn.Module):
+ """
+ Given flame parameters this class generates a differentiable FLAME function
+ which outputs the a mesh and 2D/3D facial landmarks
+ """
+
+ def __init__(
+ self,
+ shape_params,
+ expr_params,
+ flame_model_path=None,
+ flame_lmk_embedding_path=None,
+ flame_template_mesh_path=None,
+ flame_parts_path=None,
+ include_mask=True,
+ add_teeth=True,
+ add_shoulder=False,
+ teeth_bs_flag = False,
+ oral_mesh_flag = False,
+ ):
+ super().__init__()
+
+ self.n_shape_params = shape_params
+ self.n_expr_params = expr_params
+ self.use_teeth = add_teeth
+ self.flame_model_dir = os.path.dirname(flame_model_path)
+
+ with open(flame_model_path, "rb") as f:
+ ss = pickle.load(f, encoding="latin1")
+ flame_model = Struct(**ss)
+
+ self.dtype = torch.float32
+ # The vertices of the template model
+ self.register_buffer(
+ "v_template", to_tensor(to_np(flame_model.v_template), dtype=self.dtype)
+ )
+
+ # The shape components and expression
+ shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
+ shapedirs = torch.cat(
+ [shapedirs[:, :, :shape_params], shapedirs[:, :, 300 : 300 + expr_params]],
+ 2,
+ )
+ self.register_buffer("shapedirs", shapedirs)
+
+ # The pose components
+ num_pose_basis = flame_model.posedirs.shape[-1]
+ posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
+ self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
+ #
+ self.register_buffer(
+ "J_regressor", to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype)
+ )
+ parents = to_tensor(to_np(flame_model.kintree_table[0])).long()
+ parents[0] = -1
+ self.register_buffer("parents", parents)
+ self.register_buffer(
+ "lbs_weights", to_tensor(to_np(flame_model.weights), dtype=self.dtype)
+ )
+
+ # Landmark embeddings for FLAME
+ lmk_embeddings = np.load(
+ flame_lmk_embedding_path, allow_pickle=True, encoding="latin1"
+ )
+ lmk_embeddings = lmk_embeddings[()]
+ self.register_buffer(
+ "full_lmk_faces_idx",
+ torch.tensor(lmk_embeddings["full_lmk_faces_idx"], dtype=torch.long),
+ )
+ self.register_buffer(
+ "full_lmk_bary_coords",
+ torch.tensor(lmk_embeddings["full_lmk_bary_coords"], dtype=self.dtype),
+ )
+
+ neck_kin_chain = []
+ NECK_IDX = 1
+ curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
+ while curr_idx != -1:
+ neck_kin_chain.append(curr_idx)
+ curr_idx = self.parents[curr_idx]
+ self.register_buffer("neck_kin_chain", torch.stack(neck_kin_chain))
+
+ # add faces and uvs
+ verts, faces, aux = load_obj(flame_template_mesh_path, load_textures=False)
+
+ vertex_uvs = aux.verts_uvs
+ face_uvs_idx = faces.textures_idx # index into verts_uvs
+
+ pad = torch.ones(vertex_uvs.shape[0], 1)
+ vertex_uvs = torch.cat([vertex_uvs, pad], dim=-1)
+
+ face_uv_coords = face_vertices(vertex_uvs[None], face_uvs_idx[None])[0]
+ self.register_buffer("face_uvcoords", face_uv_coords, persistent=False)
+ self.register_buffer("faces", faces.verts_idx, persistent=False)
+
+ self.register_buffer("verts_uvs", aux.verts_uvs, persistent=False)
+ self.register_buffer("textures_idx", faces.textures_idx, persistent=False)
+
+ # Cal vertex mean uvs from faces for vertex uvs, so as to use FLAME subdivision.
+ vtx_ids = rearrange(self.faces, "nf nv -> (nf nv)")
+ vtx_ids = repeat(vtx_ids, "n -> n c", c=3)
+ uvs = rearrange(self.face_uvcoords, "nf nv c-> (nf nv) c")
+ N = self.v_template.shape[0]
+ sums = torch.zeros((N, 3), dtype=uvs.dtype, device=uvs.device)
+ counts = torch.zeros((N), dtype=torch.int64, device=uvs.device)
+ sums.scatter_add_(0, vtx_ids, uvs)
+ one_hot = torch.ones_like(vtx_ids[:, 0], dtype=torch.int64).to(uvs.device)
+ counts.scatter_add_(0, vtx_ids[:, 0], one_hot)
+ clamp_counts = counts.clamp(min=1)
+ vtx_uvs = sums / clamp_counts.view(-1, 1)
+
+ # Check our template mesh faces match those of FLAME:
+ assert (self.faces==torch.from_numpy(flame_model.f.astype('int64'))).all()
+ if include_mask:
+ self.mask = FlameMask(
+ flame_parts_path=flame_parts_path,
+ faces=self.faces,
+ faces_t=self.textures_idx,
+ num_verts=self.v_template.shape[0],
+ num_faces=self.faces.shape[0],
+ )
+
+ if self.use_teeth:
+ self.add_teeth()
+
+ self.teeth_bs_flag = teeth_bs_flag
+ if self.teeth_bs_flag:
+ self.add_teeth_bs()
+
+ if self.use_teeth:
+ pad = torch.ones(self.teeth_verts_uvs.shape[0], 1)
+ teeth_vtx_uvs = torch.cat([self.teeth_verts_uvs, pad], dim=-1)
+ vtx_uvs = torch.cat((vtx_uvs, teeth_vtx_uvs), dim=0)
+
+ self.add_shoulder = add_shoulder
+ if (add_shoulder):
+ shoulder_mesh = trimesh.load(os.path.join(self.flame_model_dir, 'shoulder_mesh.obj'))
+ self.v_shoulder = torch.tensor(shoulder_mesh.vertices).float()
+ self.f_shoulder = torch.tensor(shoulder_mesh.faces) + self.v_template.shape[0]
+
+ self.v_template = torch.cat([self.v_template, self.v_shoulder], dim=0)
+ self.faces = torch.cat([self.faces,self.f_shoulder])
+
+ self.oral_mesh_flag = oral_mesh_flag
+ if (self.oral_mesh_flag):
+ oral_mesh_path = os.path.join(self.flame_model_dir, 'oral_jawopen0p5.obj')
+ assert os.path.exists(oral_mesh_path), "oral_mesh_path {} is not exist!".format(oral_mesh_path)
+ oral_mesh = trimesh.load(oral_mesh_path)
+ v_oral = torch.tensor(oral_mesh.vertices).float()
+ f_oral = torch.tensor(oral_mesh.faces) + self.v_template.shape[0]
+
+ num_verts_oral = v_oral.shape[0]
+
+ shapedirs_shoulder = torch.zeros((num_verts_oral, 3, self.shapedirs.shape[2])).float()
+ self.shapedirs = torch.concat([self.shapedirs, shapedirs_shoulder], dim=0)
+
+ # posedirs set to zero
+ num_verts_orig = self.v_template.shape[0]
+ posedirs = self.posedirs.reshape(len(self.parents) - 1, 9, num_verts_orig, 3) # (J*9, V*3) -> (J, 9, V, 3)
+ posedirs = torch.cat([posedirs, torch.zeros_like(posedirs[:, :, :num_verts_oral])],
+ dim=2) # (J, 9, V+num_verts_teeth, 3)
+ self.posedirs = posedirs.reshape((len(self.parents) - 1) * 9,
+ (num_verts_orig + num_verts_oral) * 3) # (J*9, (V+num_verts_teeth)*3)
+
+ # J_regressor set to zero
+ self.J_regressor = torch.cat([self.J_regressor, torch.zeros_like(self.J_regressor[:, :num_verts_oral])],
+ dim=1) # (5, J) -> (5, J+num_verts_teeth)
+
+ # lbs_weights manually set
+ self.lbs_weights = torch.cat([self.lbs_weights, torch.zeros_like(self.lbs_weights[:num_verts_oral])],
+ dim=0) # (V, 5) -> (V+num_verts_teeth, 5)
+
+ vid_oral = torch.arange(0, num_verts_oral) + num_verts_orig
+ self.lbs_weights[vid_oral, 1] = 1
+
+ self.v_template = torch.cat([self.v_template, v_oral], dim=0)
+ self.faces = torch.cat([self.faces, f_oral], dim=0)
+
+ def add_teeth_bs(self):
+ teeth_bs_path = os.path.join(self.flame_model_dir, 'teeth_blendshape.json')
+ assert os.path.exists(teeth_bs_path), "Path {} is not exist!".format(teeth_bs_path)
+ with open(teeth_bs_path, 'r') as f:
+ bs_data = json.load(f)
+ sorted_keys = sorted(bs_data)
+ bs_data = {key: bs_data[key] for key in sorted_keys}
+ all_bs = []
+ for bs_name in bs_data:
+ current_bs = torch.from_numpy(np.array(bs_data[bs_name])).float()
+ all_verts_bs = torch.zeros((5023,3))
+ all_verts_bs = torch.cat([all_verts_bs,current_bs],dim=0)[None,...]
+ all_bs.append(all_verts_bs)
+ all_bs = torch.cat(all_bs,dim=0).permute(1,2,0)
+ self.shapedirs = torch.cat([self.shapedirs,all_bs],dim=2)
+
+ def add_teeth(self):
+ # get reference vertices from lips
+ vid_lip_outside_ring_upper = self.mask.get_vid_by_region(['lip_outside_ring_upper'], keep_order=True)
+
+ vid_lip_outside_ring_lower = self.mask.get_vid_by_region(['lip_outside_ring_lower'], keep_order=True)
+
+ v_lip_upper = self.v_template[vid_lip_outside_ring_upper]
+ v_lip_lower = self.v_template[vid_lip_outside_ring_lower]
+
+ # construct vertices for teeth
+ mean_dist = (v_lip_upper - v_lip_lower).norm(dim=-1, keepdim=True).mean()
+ v_teeth_middle = (v_lip_upper + v_lip_lower) / 2
+ v_teeth_middle[:, 1] = v_teeth_middle[:, [1]].mean(dim=0, keepdim=True)
+ # v_teeth_middle[:, 2] -= mean_dist * 2.5 # how far the teeth are from the lips
+ # v_teeth_middle[:, 2] -= mean_dist * 2 # how far the teeth are from the lips
+ v_teeth_middle[:, 2] -= mean_dist * 1.5 # how far the teeth are from the lips
+
+ # upper, front
+ v_teeth_upper_edge = v_teeth_middle.clone() + torch.tensor([[0, mean_dist, 0]])*0.1
+ v_teeth_upper_root = v_teeth_upper_edge + torch.tensor([[0, mean_dist, 0]]) * 2 # scale the height of teeth
+
+ # lower, front
+ v_teeth_lower_edge = v_teeth_middle.clone() - torch.tensor([[0, mean_dist, 0]])*0.1
+ # v_teeth_lower_edge -= torch.tensor([[0, 0, mean_dist]]) * 0.2 # slightly move the lower teeth to the back
+ v_teeth_lower_edge -= torch.tensor([[0, 0, mean_dist]]) * 0.4 # slightly move the lower teeth to the back
+ v_teeth_lower_root = v_teeth_lower_edge - torch.tensor([[0, mean_dist, 0]]) * 2 # scale the height of teeth
+
+ # thickness = mean_dist * 0.5
+ thickness = mean_dist * 1.
+ # upper, back
+ v_teeth_upper_root_back = v_teeth_upper_root.clone()
+ v_teeth_upper_edge_back = v_teeth_upper_edge.clone()
+ v_teeth_upper_root_back[:, 2] -= thickness # how thick the teeth are
+ v_teeth_upper_edge_back[:, 2] -= thickness # how thick the teeth are
+
+ # lower, back
+ v_teeth_lower_root_back = v_teeth_lower_root.clone()
+ v_teeth_lower_edge_back = v_teeth_lower_edge.clone()
+ v_teeth_lower_root_back[:, 2] -= thickness # how thick the teeth are
+ v_teeth_lower_edge_back[:, 2] -= thickness # how thick the teeth are
+
+ # concatenate to v_template
+ num_verts_orig = self.v_template.shape[0]
+ v_teeth = torch.cat([
+ v_teeth_upper_root, # num_verts_orig + 0-14
+ v_teeth_lower_root, # num_verts_orig + 15-29
+ v_teeth_upper_edge, # num_verts_orig + 30-44
+ v_teeth_lower_edge, # num_verts_orig + 45-59
+ v_teeth_upper_root_back, # num_verts_orig + 60-74
+ v_teeth_upper_edge_back, # num_verts_orig + 75-89
+ v_teeth_lower_root_back, # num_verts_orig + 90-104
+ v_teeth_lower_edge_back, # num_verts_orig + 105-119
+ ], dim=0)
+ num_verts_teeth = v_teeth.shape[0]
+ self.v_template = torch.cat([self.v_template, v_teeth], dim=0)
+
+ vid_teeth_upper_root = torch.arange(0, 15) + num_verts_orig
+ vid_teeth_lower_root = torch.arange(15, 30) + num_verts_orig
+ vid_teeth_upper_edge = torch.arange(30, 45) + num_verts_orig
+ vid_teeth_lower_edge = torch.arange(45, 60) + num_verts_orig
+ vid_teeth_upper_root_back = torch.arange(60, 75) + num_verts_orig
+ vid_teeth_upper_edge_back = torch.arange(75, 90) + num_verts_orig
+ vid_teeth_lower_root_back = torch.arange(90, 105) + num_verts_orig
+ vid_teeth_lower_edge_back = torch.arange(105, 120) + num_verts_orig
+
+ vid_teeth_upper = torch.cat([vid_teeth_upper_root, vid_teeth_upper_edge, vid_teeth_upper_root_back, vid_teeth_upper_edge_back], dim=0)
+ vid_teeth_lower = torch.cat([vid_teeth_lower_root, vid_teeth_lower_edge, vid_teeth_lower_root_back, vid_teeth_lower_edge_back], dim=0)
+ vid_teeth = torch.cat([vid_teeth_upper, vid_teeth_lower], dim=0)
+
+ # update vertex masks
+ self.mask.v.register_buffer("teeth_upper", vid_teeth_upper)
+ self.mask.v.register_buffer("teeth_lower", vid_teeth_lower)
+ self.mask.v.register_buffer("teeth", vid_teeth)
+ self.mask.v.left_half = torch.cat([
+ self.mask.v.left_half,
+ torch.tensor([
+ 5023, 5024, 5025, 5026, 5027, 5028, 5029, 5030, 5038, 5039, 5040, 5041, 5042, 5043, 5044, 5045, 5053, 5054, 5055, 5056, 5057, 5058, 5059, 5060, 5068, 5069, 5070, 5071, 5072, 5073, 5074, 5075, 5083, 5084, 5085, 5086, 5087, 5088, 5089, 5090, 5098, 5099, 5100, 5101, 5102, 5103, 5104, 5105, 5113, 5114, 5115, 5116, 5117, 5118, 5119, 5120, 5128, 5129, 5130, 5131, 5132, 5133, 5134, 5135,
+ ])], dim=0)
+
+ self.mask.v.right_half = torch.cat([
+ self.mask.v.right_half,
+ torch.tensor([
+ 5030, 5031, 5032, 5033, 5034, 5035, 5036, 5037, 5045, 5046, 5047, 5048, 5049, 5050, 5051, 5052, 5060, 5061, 5062, 5063, 5064, 5065, 5066, 5067, 5075, 5076, 5077, 5078, 5079, 5080, 5081, 5082, 5090, 5091, 5092, 5093, 5094, 5095, 5097, 5105, 5106, 5107, 5108, 5109, 5110, 5111, 5112, 5120, 5121, 5122, 5123, 5124, 5125, 5126, 5127, 5135, 5136, 5137, 5138, 5139, 5140, 5141, 5142,
+ ])], dim=0)
+
+ # construct uv vertices for teeth
+ u = torch.linspace(0.62, 0.38, 15)
+ v = torch.linspace(1-0.0083, 1-0.0425, 7)
+ # v = v[[0, 2, 1, 1]]
+ # v = v[[0, 3, 1, 4, 3, 2, 6, 5]]
+ v = v[[3, 2, 0, 1, 3, 4, 6, 5]] # TODO: with this order, teeth_lower is not rendered correctly in the uv space
+ uv = torch.stack(torch.meshgrid(u, v, indexing='ij'), dim=-1).permute(1, 0, 2).reshape(num_verts_teeth, 2) # (#num_teeth, 2)
+ num_verts_uv_orig = self.verts_uvs.shape[0]
+ num_verts_uv_teeth = uv.shape[0]
+ self.verts_uvs = torch.cat([self.verts_uvs, uv], dim=0)
+ self.teeth_verts_uvs = uv
+
+ # shapedirs copy from lips
+ self.shapedirs = torch.cat([self.shapedirs, torch.zeros_like(self.shapedirs[:num_verts_teeth])], dim=0)
+ shape_dirs_mean = (self.shapedirs[vid_lip_outside_ring_upper, :, :self.n_shape_params] + self.shapedirs[vid_lip_outside_ring_lower, :, :self.n_shape_params]) / 2
+ self.shapedirs[vid_teeth_upper_root, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_root, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_upper_edge, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_edge, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_upper_root_back, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_upper_edge_back, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_root_back, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_edge_back, :, :self.n_shape_params] = shape_dirs_mean
+
+ # posedirs set to zero
+ posedirs = self.posedirs.reshape(len(self.parents)-1, 9, num_verts_orig, 3) # (J*9, V*3) -> (J, 9, V, 3)
+ posedirs = torch.cat([posedirs, torch.zeros_like(posedirs[:, :, :num_verts_teeth])], dim=2) # (J, 9, V+num_verts_teeth, 3)
+ self.posedirs = posedirs.reshape((len(self.parents)-1)*9, (num_verts_orig+num_verts_teeth)*3) # (J*9, (V+num_verts_teeth)*3)
+
+ # J_regressor set to zero
+ self.J_regressor = torch.cat([self.J_regressor, torch.zeros_like(self.J_regressor[:, :num_verts_teeth])], dim=1) # (5, J) -> (5, J+num_verts_teeth)
+
+ # lbs_weights manually set
+ self.lbs_weights = torch.cat([self.lbs_weights, torch.zeros_like(self.lbs_weights[:num_verts_teeth])], dim=0) # (V, 5) -> (V+num_verts_teeth, 5)
+ self.lbs_weights[vid_teeth_upper, 1] += 1 # move with neck
+ self.lbs_weights[vid_teeth_lower, 2] += 1 # move with jaw
+
+ # add faces for teeth
+ f_teeth_upper = torch.tensor([
+ [0, 31, 30], #0
+ [0, 1, 31], #1
+ [1, 32, 31], #2
+ [1, 2, 32], #3
+ [2, 33, 32], #4
+ [2, 3, 33], #5
+ [3, 34, 33], #6
+ [3, 4, 34], #7
+ [4, 35, 34], #8
+ [4, 5, 35], #9
+ [5, 36, 35], #10
+ [5, 6, 36], #11
+ [6, 37, 36], #12
+ [6, 7, 37], #13
+ [7, 8, 37], #14
+ [8, 38, 37], #15
+ [8, 9, 38], #16
+ [9, 39, 38], #17
+ [9, 10, 39], #18
+ [10, 40, 39], #19
+ [10, 11, 40], #20
+ [11, 41, 40], #21
+ [11, 12, 41], #22
+ [12, 42, 41], #23
+ [12, 13, 42], #24
+ [13, 43, 42], #25
+ [13, 14, 43], #26
+ [14, 44, 43], #27
+ [60, 75, 76], # 56
+ [60, 76, 61], # 57
+ [61, 76, 77], # 58
+ [61, 77, 62], # 59
+ [62, 77, 78], # 60
+ [62, 78, 63], # 61
+ [63, 78, 79], # 62
+ [63, 79, 64], # 63
+ [64, 79, 80], # 64
+ [64, 80, 65], # 65
+ [65, 80, 81], # 66
+ [65, 81, 66], # 67
+ [66, 81, 82], # 68
+ [66, 82, 67], # 69
+ [67, 82, 68], # 70
+ [68, 82, 83], # 71
+ [68, 83, 69], # 72
+ [69, 83, 84], # 73
+ [69, 84, 70], # 74
+ [70, 84, 85], # 75
+ [70, 85, 71], # 76
+ [71, 85, 86], # 77
+ [71, 86, 72], # 78
+ [72, 86, 87], # 79
+ [72, 87, 73], # 80
+ [73, 87, 88], # 81
+ [73, 88, 74], # 82
+ [74, 88, 89], # 83
+ [75, 30, 76], # 84
+ [76, 30, 31], # 85
+ [76, 31, 77], # 86
+ [77, 31, 32], # 87
+ [77, 32, 78], # 88
+ [78, 32, 33], # 89
+ [78, 33, 79], # 90
+ [79, 33, 34], # 91
+ [79, 34, 80], # 92
+ [80, 34, 35], # 93
+ [80, 35, 81], # 94
+ [81, 35, 36], # 95
+ [81, 36, 82], # 96
+ [82, 36, 37], # 97
+ [82, 37, 38], # 98
+ [82, 38, 83], # 99
+ [83, 38, 39], # 100
+ [83, 39, 84], # 101
+ [84, 39, 40], # 102
+ [84, 40, 85], # 103
+ [85, 40, 41], # 104
+ [85, 41, 86], # 105
+ [86, 41, 42], # 106
+ [86, 42, 87], # 107
+ [87, 42, 43], # 108
+ [87, 43, 88], # 109
+ [88, 43, 44], # 110
+ [88, 44, 89], # 111
+ ])
+ f_teeth_lower = torch.tensor([
+ [45, 46, 15], # 28
+ [46, 16, 15], # 29
+ [46, 47, 16], # 30
+ [47, 17, 16], # 31
+ [47, 48, 17], # 32
+ [48, 18, 17], # 33
+ [48, 49, 18], # 34
+ [49, 19, 18], # 35
+ [49, 50, 19], # 36
+ [50, 20, 19], # 37
+ [50, 51, 20], # 38
+ [51, 21, 20], # 39
+ [51, 52, 21], # 40
+ [52, 22, 21], # 41
+ [52, 23, 22], # 42
+ [52, 53, 23], # 43
+ [53, 24, 23], # 44
+ [53, 54, 24], # 45
+ [54, 25, 24], # 46
+ [54, 55, 25], # 47
+ [55, 26, 25], # 48
+ [55, 56, 26], # 49
+ [56, 27, 26], # 50
+ [56, 57, 27], # 51
+ [57, 28, 27], # 52
+ [57, 58, 28], # 53
+ [58, 29, 28], # 54
+ [58, 59, 29], # 55
+ [90, 106, 105], # 112
+ [90, 91, 106], # 113
+ [91, 107, 106], # 114
+ [91, 92, 107], # 115
+ [92, 108, 107], # 116
+ [92, 93, 108], # 117
+ [93, 109, 108], # 118
+ [93, 94, 109], # 119
+ [94, 110, 109], # 120
+ [94, 95, 110], # 121
+ [95, 111, 110], # 122
+ [95, 96, 111], # 123
+ [96, 112, 111], # 124
+ [96, 97, 112], # 125
+ [97, 98, 112], # 126
+ [98, 113, 112], # 127
+ [98, 99, 113], # 128
+ [99, 114, 113], # 129
+ [99, 100, 114], # 130
+ [100, 115, 114], # 131
+ [100, 101, 115], # 132
+ [101, 116, 114], # 133
+ [101, 102, 116], # 134
+ [102, 117, 116], # 135
+ [102, 103, 117], # 136
+ [103, 118, 117], # 137
+ [103, 104, 118], # 138
+ [104, 119, 118], # 139
+ [105, 106, 45], # 140
+ [106, 46, 45], # 141
+ [106, 107, 46], # 142
+ [107, 47, 46], # 143
+ [107, 108, 47], # 144
+ [108, 48, 47], # 145
+ [108, 109, 48], # 146
+ [109, 49, 48], # 147
+ [109, 110, 49], # 148
+ [110, 50, 49], # 149
+ [110, 111, 50], # 150
+ [111, 51, 50], # 151
+ [111, 112, 51], # 152
+ [112, 52, 51], # 153
+ [112, 53, 52], # 154
+ [112, 113, 53], # 155
+ [113, 54, 53], # 156
+ [113, 114, 54], # 157
+ [114, 55, 54], # 158
+ [114, 115, 55], # 159
+ [115, 56, 55], # 160
+ [115, 116, 56], # 161
+ [116, 57, 56], # 162
+ [116, 117, 57], # 163
+ [117, 58, 57], # 164
+ [117, 118, 58], # 165
+ [118, 59, 58], # 166
+ [118, 119, 59], # 167
+ ])
+ self.faces = torch.cat([self.faces, f_teeth_upper+num_verts_orig, f_teeth_lower+num_verts_orig], dim=0)
+ self.textures_idx = torch.cat([self.textures_idx, f_teeth_upper+num_verts_uv_orig, f_teeth_lower+num_verts_uv_orig], dim=0)
+
+ self.mask.update(self.faces, self.textures_idx)
+
+ def forward(
+ self,
+ shape,
+ expr,
+ rotation,
+ neck,
+ jaw,
+ eyes,
+ translation,
+ zero_centered_at_root_node=False, # otherwise, zero centered at the face
+ return_landmarks=True,
+ return_verts_cano=False,
+ static_offset=None,
+ dynamic_offset=None,
+ ):
+ """
+ Input:
+ shape_params: N X number of shape parameters
+ expression_params: N X number of expression parameters
+ pose_params: N X number of pose parameters (6)
+ return:d
+ vertices: N X V X 3
+ landmarks: N X number of landmarks X 3
+ """
+ batch_size = shape.shape[0]
+
+ betas = torch.cat([shape, expr], dim=1)
+ full_pose = torch.cat([rotation, neck, jaw, eyes], dim=1)
+
+ if(self.add_shoulder):
+ template_vertices = self.v_template[:(self.v_template.shape[0]-self.v_shoulder.shape[0])].unsqueeze(0).expand(batch_size, -1, -1)
+ else:
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Add shape contribution
+ v_shaped_woexpr = template_vertices + blend_shapes(torch.cat([betas[:, :self.n_shape_params],
+ torch.zeros_like(betas[:, self.n_shape_params:])],
+ dim=1), self.shapedirs)
+ v_shaped = template_vertices + blend_shapes(betas, self.shapedirs)
+
+ # Add personal offsets
+ if static_offset is not None:
+ if (self.add_shoulder):
+ v_shaped += static_offset[:,:(self.v_template.shape[0]-self.v_shoulder.shape[0])]
+ else:
+ v_shaped += static_offset
+
+ vertices, J, mat_rot = lbs(
+ full_pose,
+ v_shaped,
+ self.posedirs,
+ self.J_regressor,
+ self.parents,
+ self.lbs_weights,
+ dtype=self.dtype,
+ )
+ if (self.add_shoulder):
+ v_shaped = torch.cat([v_shaped, self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(0).expand(batch_size, -1, -1)], dim=1)
+ vertices = torch.cat([vertices, self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(0).expand(batch_size, -1, -1)], dim=1)
+
+ if zero_centered_at_root_node:
+ vertices = vertices - J[:, [0]]
+ J = J - J[:, [0]]
+
+ vertices = vertices + translation[:, None, :]
+ J = J + translation[:, None, :]
+
+ ret_vals = {}
+ ret_vals["animated"] =vertices
+
+ if return_verts_cano:
+ ret_vals["cano"] = v_shaped_woexpr
+ ret_vals["cano_with_expr"] = v_shaped
+
+ # compute landmarks if desired
+ if return_landmarks:
+ bz = vertices.shape[0]
+ landmarks = vertices2landmarks(
+ vertices,
+ self.faces,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1),
+ )
+ ret_vals["landmarks"] = landmarks
+
+ return ret_vals
+
+
+
+class FlameHeadSubdivided(FlameHead):
+ """
+ Given flame parameters this class generates a differentiable FLAME function
+ which outputs the a mesh and 2D/3D facial landmarks
+ """
+
+ def __init__(
+ self,
+ shape_params,
+ expr_params,
+ flame_model_path=None,
+ flame_lmk_embedding_path=None,
+ flame_template_mesh_path=None,
+ flame_parts_path=None,
+ include_mask=True,
+ add_teeth=True,
+ add_shoulder=False,
+ subdivide_num=0,
+ teeth_bs_flag = False,
+ oral_mesh_flag = False,
+ ):
+ super().__init__(shape_params=shape_params,
+ expr_params=expr_params,
+ flame_model_path=flame_model_path,
+ flame_lmk_embedding_path=flame_lmk_embedding_path,
+ flame_template_mesh_path=flame_template_mesh_path,
+ include_mask=include_mask,
+ add_teeth=add_teeth,
+ add_shoulder=add_shoulder,
+ flame_parts_path=flame_parts_path,
+ teeth_bs_flag = teeth_bs_flag,
+ oral_mesh_flag = oral_mesh_flag,
+ )
+
+ # subdivider
+ self.subdivide_num = subdivide_num
+ self.subdivider_list = self.get_subdivider(subdivide_num)
+ self.subdivider_cpu_list = self.get_subdivider_cpu(subdivide_num)
+ self.face_upsampled = self.subdivider_list[-1]._subdivided_faces.cpu().numpy() if self.subdivide_num > 0 else self.faces.numpy()
+ self.vertex_num_upsampled = int(np.max(self.face_upsampled) + 1)
+
+ self.vertex_num = self.v_template.shape[0]
+ self.joint_num = self.J_regressor.shape[0]
+ print(f"face_upsampled:{self.face_upsampled.shape}, face_ori:{self.faces.shape}, \
+ vertex_num_upsampled:{self.vertex_num_upsampled}, vertex_num_ori:{self.vertex_num}")
+
+ lbs_weights = self.lbs_weights.float()
+ posedirs = self.posedirs.permute(1, 0).reshape(self.vertex_num, 3 * (self.joint_num - 1) * 9)
+ shapedirs = self.shapedirs.view(self.vertex_num, 3 * (self.n_shape_params + self.n_expr_params + (4 if self.teeth_bs_flag else 0)))
+ J_regressor = self.J_regressor.permute(1, 0)
+
+ attributes = [lbs_weights, posedirs, shapedirs, J_regressor]
+ ret = self.upsample_mesh_cpu(self.v_template.float(), attributes,) # upsample with dummy vertex
+ v_template_upsampled, lbs_weights, posedirs, shapedirs, J_regressor = ret
+
+ posedirs = posedirs.reshape(self.vertex_num_upsampled * 3, (self.joint_num-1) * 9).permute(1, 0)
+ shapedirs = shapedirs.view(self.vertex_num_upsampled, 3 , (self.n_shape_params + self.n_expr_params + (4 if self.teeth_bs_flag else 0)))
+ J_regressor = J_regressor.permute(1, 0)
+
+ self.register_buffer('faces_up', torch.from_numpy(self.face_upsampled).to(shapedirs.device))
+ self.register_buffer('v_template_up', v_template_upsampled.contiguous())
+ self.register_buffer('lbs_weights_up', lbs_weights.contiguous())
+ self.register_buffer('shapedirs_up', shapedirs.contiguous())
+
+ def get_cano_verts(self, shape_params):
+ # TODO check
+ assert self.add_shoulder == False
+ batch_size = shape_params.shape[0]
+
+ template_vertices = self.v_template_up.unsqueeze(0).expand(batch_size, -1, -1)
+
+ v_shaped = template_vertices + blend_shapes(shape_params, self.shapedirs_up[:, :, :self.n_shape_params])
+
+ return v_shaped
+
+ def animation_forward(self,
+ v_cano,
+ shape,
+ expr,
+ rotation,
+ neck,
+ jaw,
+ eyes,
+ translation,
+ zero_centered_at_root_node=False, # otherwise, zero centered at the face
+ return_landmarks=True,
+ return_verts_cano=False,
+ static_offset=None,
+ dynamic_offset=None,
+ ):
+ assert self.add_shoulder == False
+ assert static_offset is None
+
+ batch_size = shape.shape[0]
+
+ # step1. get animated_joint and corresponding transformed mat (Note not in upsampled space)
+ betas = torch.cat([shape, expr], dim=1)
+ full_pose = torch.cat([rotation, neck, jaw, eyes], dim=1)
+
+ if(self.add_shoulder):
+ template_vertices = self.v_template[:(self.v_template.shape[0]-self.v_shoulder.shape[0])].unsqueeze(0).expand(batch_size, -1, -1)
+ else:
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Add shape contribution
+ v_shaped = template_vertices + blend_shapes(betas, self.shapedirs)
+
+ # Add personal offsets
+ if static_offset is not None:
+ if (self.add_shoulder):
+ v_shaped += static_offset[:,:(self.v_template.shape[0]-self.v_shoulder.shape[0])]
+ else:
+ v_shaped += static_offset
+
+ A, J = self.get_transformed_mat(pose=full_pose, v_shaped=v_shaped, posedirs=self.posedirs,
+ parents=self.parents, J_regressor=self.J_regressor, pose2rot=True,
+ dtype=self.dtype)
+
+ # step2. v_cano_with_expr
+ v_cano_with_expr = v_cano + blend_shapes(expr, self.shapedirs_up[:, :, self.n_shape_params:])
+
+ # step3. lbs
+ vertices = self.skinning(v_posed=v_cano_with_expr, A=A, lbs_weights=self.lbs_weights_up, batch_size=batch_size,
+ num_joints=self.joint_num, dtype=self.dtype, device=full_pose.device)
+
+ if (self.add_shoulder):
+ v_shaped = torch.cat([v_shaped, self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(0).expand(batch_size, -1, -1)], dim=1)
+ vertices = torch.cat([vertices, self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(0).expand(batch_size, -1, -1)], dim=1)
+
+ if zero_centered_at_root_node:
+ vertices = vertices - J[:, [0]]
+ J = J - J[:, [0]]
+
+ vertices = vertices + translation[:, None, :]
+ J = J + translation[:, None, :]
+
+ ret_vals = {}
+ ret_vals["animated"] =vertices
+
+ if return_verts_cano:
+ ret_vals["cano"] = v_cano
+ ret_vals["cano_with_expr"] = v_cano_with_expr
+
+ # compute landmarks if desired
+ if return_landmarks:
+ bz = vertices.shape[0]
+ landmarks = vertices2landmarks(
+ vertices,
+ self.faces,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1),
+ )
+ ret_vals["landmarks"] = landmarks
+
+ return ret_vals
+
+ def get_transformed_mat(self, pose, v_shaped, posedirs, parents, J_regressor, pose2rot, dtype):
+ batch_size = pose.shape[0]
+ device = pose.device
+
+ # Get the joints
+ # NxJx3 array
+ J = vertices2joints(J_regressor, v_shaped)
+
+ # 3. Add pose blend shapes
+ # N x J x 3 x 3
+ ident = torch.eye(3, dtype=dtype, device=device)
+ if pose2rot:
+ rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view(
+ [batch_size, -1, 3, 3]
+ )
+
+ pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
+ # (N x P) x (P, V * 3) -> N x V x 3
+ pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
+ else:
+ pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
+ rot_mats = pose.view(batch_size, -1, 3, 3)
+
+ pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(
+ batch_size, -1, 3
+ )
+
+ v_posed = pose_offsets + v_shaped
+
+ # 4. Get the global joint location
+ J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
+
+ return A, J_transformed
+
+ def skinning(self, v_posed, A, lbs_weights, batch_size, num_joints, dtype, device):
+
+ # 5. Do skinning:
+ # W is N x V x (J + 1)
+ W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
+ # (N x V x (J + 1)) x (N x (J + 1) x 16)
+ # num_joints = J_regressor.shape[0]
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
+
+ homogen_coord = torch.ones(
+ [batch_size, v_posed.shape[1], 1], dtype=dtype, device=device
+ )
+ v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
+ v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
+ verts = v_homo[:, :, :3, 0]
+
+ return verts
+
+ def inverse_animation(self,
+ v_pose,
+ shape,
+ expr,
+ rotation,
+ neck,
+ jaw,
+ eyes,
+ translation,
+ zero_centered_at_root_node=False, # otherwise, zero centered at the face
+ return_landmarks=True,
+ return_verts_cano=False,
+ static_offset=None,
+ dynamic_offset=None,
+ ):
+ assert self.add_shoulder == False
+ assert static_offset is None
+
+ batch_size = shape.shape[0]
+
+ # step1. get animated_joint and corresponding transformed mat (Note not in upsampled space)
+ betas = torch.cat([shape, expr], dim=1)
+ full_pose = torch.cat([rotation, neck, jaw, eyes], dim=1)
+
+ if(self.add_shoulder):
+ template_vertices = self.v_template[:(self.v_template.shape[0]-self.v_shoulder.shape[0])].unsqueeze(0).expand(batch_size, -1, -1)
+ else:
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Add shape contribution
+ v_shaped = template_vertices + blend_shapes(betas, self.shapedirs)
+
+ # Add personal offsets
+ if static_offset is not None:
+ if (self.add_shoulder):
+ v_shaped += static_offset[:,:(self.v_template.shape[0]-self.v_shoulder.shape[0])]
+ else:
+ v_shaped += static_offset
+
+ A, J = self.get_transformed_mat(pose=full_pose, v_shaped=v_shaped, posedirs=self.posedirs,
+ parents=self.parents, J_regressor=self.J_regressor, pose2rot=True,
+ dtype=self.dtype)
+
+ v_pose = v_pose - translation[:, None, :]
+
+ # inverse lbs
+ v_cano_with_expr = self.inverse_skinning(v_posed=v_pose, A=A, lbs_weights=self.lbs_weights_up, batch_size=batch_size,
+ num_joints=self.joint_num, dtype=self.dtype, device=full_pose.device)
+
+ # step2. v_cano
+ v_cano = v_cano_with_expr - blend_shapes(expr, self.shapedirs_up[:, :, self.n_shape_params:])
+
+ # step3. lbs
+ if (self.add_shoulder):
+ v_shaped = torch.cat([v_shaped, self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(0).expand(batch_size, -1, -1)], dim=1)
+ v_cano = torch.cat([v_cano, self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(0).expand(batch_size, -1, -1)], dim=1)
+
+ if zero_centered_at_root_node:
+ v_cano = v_cano - J[:, [0]]
+ J = J - J[:, [0]]
+
+
+ ret_vals = {}
+ ret_vals["cano"] = v_cano
+
+ if return_verts_cano:
+ ret_vals["cano_with_expr"] = v_cano_with_expr
+
+ # compute landmarks if desired
+ if return_landmarks:
+ bz = v_cano.shape[0]
+ landmarks = vertices2landmarks(
+ v_cano,
+ self.faces,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1),
+ )
+ ret_vals["landmarks"] = landmarks
+
+ return ret_vals
+
+ def inverse_skinning(self, v_posed, A, lbs_weights, batch_size, num_joints, dtype, device):
+
+ # 5. Do skinning:
+ # W is N x V x (J + 1)
+ W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
+ # (N x V x (J + 1)) x (N x (J + 1) x 16)
+ # num_joints = J_regressor.shape[0]
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
+
+ homogen_coord = torch.ones(
+ [batch_size, v_posed.shape[1], 1], dtype=dtype, device=device
+ )
+ v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
+ v_homo = torch.matmul(torch.inverse(T), torch.unsqueeze(v_posed_homo, dim=-1))
+ verts = v_homo[:, :, :3, 0]
+
+ return verts
+
+ def forward(
+ self,
+ shape,
+ expr,
+ rotation,
+ neck,
+ jaw,
+ eyes,
+ translation,
+ zero_centered_at_root_node=False, # otherwise, zero centered at the face
+ return_landmarks=True,
+ return_verts_cano=False,
+ static_offset=None,
+ dynamic_offset=None,
+ ):
+ """
+ Input:
+ shape_params: N X number of shape parameters
+ expression_params: N X number of expression parameters
+ pose_params: N X number of pose parameters (6)
+ return:d
+ vertices: N X V X 3
+ landmarks: N X number of landmarks X 3
+ """
+ batch_size = shape.shape[0]
+
+ betas = torch.cat([shape, expr], dim=1)
+ full_pose = torch.cat([rotation, neck, jaw, eyes], dim=1)
+
+ if(self.add_shoulder):
+ template_vertices = self.v_template[:(self.v_template.shape[0]-self.v_shoulder.shape[0])].unsqueeze(0).expand(batch_size, -1, -1)
+ else:
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Add shape contribution
+ v_shaped_woexpr = template_vertices + blend_shapes(betas[:, :self.n_shape_params], self.shapedirs[:, :, :self.n_shape_params])
+ v_shaped = template_vertices + blend_shapes(betas, self.shapedirs)
+
+
+ # Add personal offsets
+ if static_offset is not None:
+ if (self.add_shoulder):
+ v_shaped += static_offset[:,:(self.v_template.shape[0]-self.v_shoulder.shape[0])]
+ else:
+ v_shaped += static_offset
+
+ A, J = self.get_transformed_mat(pose=full_pose, v_shaped=v_shaped, posedirs=self.posedirs,
+ parents=self.parents, J_regressor=self.J_regressor, pose2rot=True,
+ dtype=self.dtype)
+
+ v_shaped_up = self.v_template_up.unsqueeze(0).expand(batch_size, -1, -1) + blend_shapes(betas, self.shapedirs_up)
+ vertices = self.skinning(v_posed=v_shaped_up, A=A, lbs_weights=self.lbs_weights_up, batch_size=batch_size,
+ num_joints=self.joint_num, dtype=self.dtype, device=full_pose.device)
+
+
+ if (self.add_shoulder):
+ v_shaped = torch.cat([v_shaped, self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(0).expand(batch_size, -1, -1)], dim=1)
+ vertices = torch.cat([vertices, self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(0).expand(batch_size, -1, -1)], dim=1)
+
+ if zero_centered_at_root_node:
+ vertices = vertices - J[:, [0]]
+ J = J - J[:, [0]]
+
+ vertices = vertices + translation[:, None, :]
+ J = J + translation[:, None, :]
+
+ ret_vals = {}
+ ret_vals["animated"] =vertices
+
+ if return_verts_cano:
+ ret_vals["cano"] = self.v_template_up.unsqueeze(0).expand(batch_size, -1, -1) + blend_shapes(betas[:, :self.n_shape_params], self.shapedirs_up[:, :, :self.n_shape_params])
+ ret_vals["cano_with_expr"] = v_shaped_up
+
+ # compute landmarks if desired
+ if return_landmarks:
+ bz = vertices.shape[0]
+ landmarks = vertices2landmarks(
+ vertices,
+ self.faces,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1),
+ )
+ ret_vals["landmarks"] = landmarks
+
+ return ret_vals
+
+ def get_subdivider(self, subdivide_num):
+ vert = self.v_template.float().cuda()
+ face = torch.LongTensor(self.faces).cuda()
+ mesh = Meshes(vert[None,:,:], face[None,:,:])
+
+ if subdivide_num > 0:
+ subdivider_list = [SubdivideMeshes(mesh)]
+ for i in range(subdivide_num-1):
+ mesh = subdivider_list[-1](mesh)
+ subdivider_list.append(SubdivideMeshes(mesh))
+ else:
+ subdivider_list = [mesh]
+ return subdivider_list
+
+ def get_subdivider_cpu(self, subdivide_num):
+ vert = self.v_template.float()
+ face = torch.LongTensor(self.faces)
+ mesh = Meshes(vert[None,:,:], face[None,:,:])
+
+ if subdivide_num > 0:
+ subdivider_list = [SubdivideMeshes(mesh)]
+ for i in range(subdivide_num-1):
+ mesh = subdivider_list[-1](mesh)
+ subdivider_list.append(SubdivideMeshes(mesh))
+ else:
+ subdivider_list = [mesh]
+ return subdivider_list
+
+ def upsample_mesh_cpu(self, vert, feat_list=None):
+ face = torch.LongTensor(self.faces)
+ mesh = Meshes(vert[None,:,:], face[None,:,:])
+ if self.subdivide_num > 0:
+ if feat_list is None:
+ for subdivider in self.subdivider_cpu_list:
+ mesh = subdivider(mesh)
+ vert = mesh.verts_list()[0]
+ return vert
+ else:
+ feat_dims = [x.shape[1] for x in feat_list]
+ feats = torch.cat(feat_list,1)
+ for subdivider in self.subdivider_cpu_list:
+ mesh, feats = subdivider(mesh, feats)
+ vert = mesh.verts_list()[0]
+ feats = feats[0]
+ feat_list = torch.split(feats, feat_dims, dim=1)
+ return vert, *feat_list
+ else:
+ if feat_list is None:
+ return vert
+ else:
+ return vert, *feat_list
+
+ def upsample_mesh(self, vert, feat_list=None, device="cuda"):
+ face = torch.LongTensor(self.faces).to(device)
+ mesh = Meshes(vert[None,:,:], face[None,:,:])
+ if self.subdivide_num > 0:
+ if feat_list is None:
+ for subdivider in self.subdivider_list:
+ mesh = subdivider(mesh)
+ vert = mesh.verts_list()[0]
+ return vert
+ else:
+ feat_dims = [x.shape[1] for x in feat_list]
+ feats = torch.cat(feat_list,1)
+ for subdivider in self.subdivider_list:
+ mesh, feats = subdivider(mesh, feats)
+ vert = mesh.verts_list()[0]
+ feats = feats[0]
+ feat_list = torch.split(feats, feat_dims, dim=1)
+ return vert, *feat_list
+ else:
+ if feat_list is None:
+ return vert
+ else:
+ return vert, *feat_list
+
+
+ def upsample_mesh_batch(self, vert, device="cuda"):
+ if self.subdivide_num > 0:
+ face = torch.LongTensor(self.faces).to(device).unsqueeze(0).repeat(vert.shape[0], 1, 1)
+ mesh = Meshes(vert, face)
+ for subdivider in self.subdivider_list:
+ mesh = subdivider(mesh)
+ vert = torch.stack(mesh.verts_list(), dim=0)
+ else:
+ pass
+ return vert
+
+
+class BufferContainer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def __repr__(self):
+ main_str = super().__repr__() + '\n'
+ for name, buf in self.named_buffers():
+ main_str += f' {name:20}\t{buf.shape}\t{buf.dtype}\n'
+ return main_str
+
+ def __iter__(self):
+ for name, buf in self.named_buffers():
+ yield name, buf
+
+ def keys(self):
+ return [name for name, buf in self.named_buffers()]
+
+ def items(self):
+ return [(name, buf) for name, buf in self.named_buffers()]
+
+class FlameMask(nn.Module):
+ def __init__(
+ self,
+ flame_parts_path=None,
+ faces=None,
+ faces_t=None,
+ num_verts=5023,
+ num_faces=9976,
+ face_clusters=[],
+ ):
+ super().__init__()
+ self.faces = faces
+ self.faces_t = faces_t
+ self.face_clusters = face_clusters
+ self.num_verts = num_verts
+ if faces is not None:
+ self.num_faces = faces.shape[0]
+ else:
+ self.num_faces = num_faces
+
+ self.process_vertex_mask(flame_parts_path)
+
+ if self.faces is not None:
+ self.construct_vid_table()
+ self.process_face_mask(self.faces)
+ self.process_face_clusters(self.face_clusters)
+ if self.faces_t is not None:
+ self.process_vt_mask(self.faces, self.faces_t)
+
+ def update(self, faces=None, faces_t=None, face_clusters=None):
+ """Update the faces properties when vertex masks are changed"""
+ if faces is not None:
+ self.faces = faces
+ self.num_faces = faces.shape[0]
+ if faces_t is not None:
+ self.faces_t = faces_t
+ if face_clusters is not None:
+ self.face_clusters = face_clusters
+
+ self.construct_vid_table()
+ self.process_face_mask(self.faces)
+ self.process_face_clusters(self.face_clusters)
+ if self.faces_t is not None:
+ self.process_vt_mask(self.faces, self.faces_t)
+
+ def process_vertex_mask(self, flame_parts_path):
+ """Load the vertex masks from the FLAME model and add custom masks"""
+
+ part_masks = np.load(flame_parts_path, allow_pickle=True, encoding="latin1")
+ """ Available part masks from the FLAME model:
+ face, neck, scalp, boundary, right_eyeball, left_eyeball,
+ right_ear, left_ear, forehead, eye_region, nose, lips,
+ right_eye_region, left_eye_region.
+ """
+
+ self.v = BufferContainer()
+ for k, v_mask in part_masks.items():
+ self.v.register_buffer(k, torch.tensor(v_mask, dtype=torch.long))
+
+ self.create_custom_mask()
+
+ def create_custom_mask(self):
+ """Add some cutom masks based on the original FLAME masks"""
+
+ self.v.register_buffer("neck_left_point", torch.tensor([3193]))
+ self.v.register_buffer("neck_right_point", torch.tensor([3296]))
+ self.v.register_buffer("front_middle_bottom_point_boundary", torch.tensor([3285]))
+ self.v.register_buffer("back_middle_bottom_point_boundary", torch.tensor([3248]))
+
+ self.v.register_buffer(
+ "neck_top",
+ torch.tensor([
+ 10, 11, 111, 112, 784, 795, 1325, 1901, 2115, 2162, 2251, 2254, 2483, 2979, 3142, 3174, 3441, 3442, 3443, 3444, 3445, 3446, 3447, 3448, 3449, 3562, 3673, 3676, 3677, 3678, 3679, 3680, 3681, 3685,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_ring_upper",
+ torch.tensor([
+ 1595, 1746, 1747, 1742, 1739, 1665, 1666, 3514, 2783, 2782, 2854, 2857, 2862, 2861, 2731
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_ring_lower",
+ torch.tensor([
+ 1572, 1573, 1860, 1862, 1830, 1835, 1852, 3497, 2941, 2933, 2930, 2945, 2943, 2709, 2708
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_outside_ring_upper",
+ torch.tensor([
+ 1713, 1715, 1716, 1735, 1696, 1694, 1657, 3543, 2774, 2811, 2813, 2850, 2833, 2832, 2830
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_outside_ring_lower",
+ torch.tensor([
+ 1576, 1577, 1773, 1774, 1795, 1802, 1865, 3503, 2948, 2905, 2898, 2881, 2880, 2713, 2712
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_upper",
+ torch.tensor([
+ 1588, 1589, 1590, 1591, 1594, 1595, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1724, 1725, 1739, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 2724, 2725, 2726, 2727, 2730, 2731, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2841, 2842, 2854, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 3514, 3547, 3549,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_lower",
+ torch.tensor([
+ 1572, 1573, 1592, 1593, 1764, 1765, 1779, 1780, 1781, 1830, 1831, 1832, 1835, 1846, 1847, 1851, 1852, 1854, 1860, 1861, 1862, 2708, 2709, 2728, 2729, 2872, 2873, 2886, 2887, 2888, 2930, 2931, 2932, 2933, 2935, 2936, 2940, 2941, 2942, 2943, 2944, 2945, 3497, 3500, 3512,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside",
+ torch.tensor([
+ 1572, 1573, 1580, 1581, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1718, 1719, 1722, 1724, 1725, 1728, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1764, 1765, 1777, 1778, 1779, 1780, 1781, 1782, 1827, 1830, 1831, 1832, 1835, 1836, 1846, 1847, 1851, 1852, 1854, 1860, 1861, 1862, 2708, 2709, 2716, 2717, 2724, 2725, 2726, 2727, 2728, 2729, 2730, 2731, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2784, 2785, 2835, 2836, 2839, 2841, 2842, 2843, 2854, 2855, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 2863, 2872, 2873, 2884, 2885, 2886, 2887, 2888, 2889, 2929, 2930, 2931, 2932, 2933, 2934, 2935, 2936, 2940, 2941, 2942, 2943, 2944, 2945, 3497, 3500, 3512, 3513, 3514, 3533, 3547, 3549,
+ ])
+ )
+
+ self.v.register_buffer(
+ "neck_upper",
+ torch.tensor([
+ 10, 11, 12, 13, 14, 15, 111, 112, 219, 220, 221, 222, 372, 373, 374, 375, 462, 463, 496, 497, 552, 553, 558, 559, 563, 564, 649, 650, 736, 737, 784, 795, 1210, 1211, 1212, 1213, 1325, 1326, 1359, 1360, 1386, 1726, 1727, 1759, 1790, 1886, 1898, 1901, 1931, 1932, 1933, 1934, 1940, 1941, 1948, 1949, 2036, 2115, 2149, 2150, 2151, 2162, 2218, 2219, 2251, 2254, 2483, 2484, 2531, 2870, 2893, 2964, 2976, 2979, 3012, 3013, 3142, 3174, 3184, 3185, 3186, 3187, 3188, 3189, 3193, 3194, 3196, 3199, 3200, 3202, 3203, 3206, 3209, 3281, 3282, 3286, 3291, 3292, 3296, 3297, 3299, 3302, 3303, 3305, 3306, 3309, 3312, 3376, 3441, 3442, 3443, 3444, 3445, 3446, 3447, 3448, 3449, 3452, 3453, 3454, 3455, 3456, 3457, 3458, 3459, 3460, 3461, 3462, 3463, 3494, 3496, 3544, 3562, 3673, 3676, 3677, 3678, 3679, 3680, 3681, 3685, 3695, 3697, 3698, 3701, 3703, 3707, 3709, 3713,
+ ])
+ )
+
+ self.v.register_buffer(
+ "neck_lower",
+ torch.tensor([
+ 3188, 3189, 3190, 3191, 3192, 3193, 3194, 3195, 3196, 3197, 3198, 3199, 3200, 3201, 3202, 3203, 3204, 3205, 3206, 3207, 3208, 3209, 3210, 3211, 3212, 3213, 3214, 3215, 3220, 3222, 3223, 3231, 3232, 3233, 3234, 3235, 3236, 3237, 3238, 3239, 3240, 3241, 3242, 3243, 3244, 3245, 3246, 3247, 3250, 3251, 3253, 3254, 3263, 3264, 3265, 3266, 3267, 3268, 3269, 3270, 3275, 3276, 3277, 3278, 3281, 3282, 3283, 3286, 3288, 3290, 3291, 3292, 3293, 3294, 3295, 3296, 3297, 3298, 3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308, 3309, 3310, 3311, 3312, 3313, 3314, 3315, 3316, 3317, 3318, 3323, 3332, 3333, 3334, 3335, 3336, 3337, 3338, 3339, 3340, 3341, 3342, 3343, 3344, 3345, 3346, 3347, 3348, 3349, 3350, 3352, 3353, 3362, 3363, 3364, 3365, 3366, 3367, 3368, 3369, 3376, 3378,
+ ])
+ )
+
+ # the bottomline of "neck"
+ self.v.register_buffer(
+ "neck_base",
+ torch.tensor([
+ 3231, 3232, 3237, 3238, 3240, 3242, 3243, 3251, 3263, 3290, 3332, 3333, 3338, 3339, 3341, 3343, 3344, 3350, 3362, # 4-th ring from bottom (drop 7 front verts)
+ ])
+ )
+
+ # As a subset of "boundary", "bottomline" only contains vertices on the edge
+ self.v.register_buffer(
+ "bottomline",
+ torch.tensor([
+ 3218, 3219, 3226, 3272, 3273, 3229, 3228, 3261, 3260, 3248, 3359, 3360, 3329, 3330, 3372, 3371, 3327, 3322, 3321, 3355, 3354, 3356, 3357, 3379, 3285, 3289, 3258, 3257, 3255, 3256
+ ])
+ )
+
+ self.v.register_buffer(
+ "left_iris",
+ torch.tensor([
+ 3931, 3932, 3933, 3935, 3936, 3937, 3939, 3940, 3941, 3943, 3944, 3945, 3947, 3948, 3949, 3951, 3952, 3953, 3955, 3956, 3957, 3959, 3960, 3961, 3963, 3964, 3965, 3967, 3968, 3969, 3971, 3972, 3973, 3975, 3976, 3977, 3979, 3980, 3981, 3983, 3984, 3985, 3987, 3988, 3989, 3991, 3992, 3993, 3995, 3996, 3997, 3999, 4000, 4001, 4003, 4004, 4005, 4007, 4008, 4009, 4011, 4012, 4013, 4015, 4016, 4017, 4019, 4020, 4021, 4023, 4024, 4025, 4027, 4028, 4029, 4031, 4032, 4033, 4035, 4036, 4037, 4039, 4040, 4041, 4043, 4044, 4045, 4047, 4048, 4049, 4051, 4052, 4053, 4054, 4056, 4057, 4058,
+ ])
+ )
+
+ self.v.register_buffer(
+ "right_iris",
+ torch.tensor([
+ 4477, 4478, 4479, 4481, 4482, 4483, 4485, 4486, 4487, 4489, 4490, 4491, 4493, 4494, 4495, 4497, 4498, 4499, 4501, 4502, 4503, 4505, 4506, 4507, 4509, 4510, 4511, 4513, 4514, 4515, 4517, 4518, 4519, 4521, 4522, 4523, 4525, 4526, 4527, 4529, 4530, 4531, 4533, 4534, 4535, 4537, 4538, 4539, 4541, 4542, 4543, 4545, 4546, 4547, 4549, 4550, 4551, 4553, 4554, 4555, 4557, 4558, 4559, 4561, 4562, 4563, 4565, 4566, 4567, 4569, 4570, 4571, 4573, 4574, 4575, 4577, 4578, 4579, 4581, 4582, 4583, 4585, 4586, 4587, 4589, 4590, 4591, 4593, 4594, 4595, 4597, 4598, 4599, 4600, 4602, 4603, 4604,
+ ])
+ )
+
+ self.v.register_buffer(
+ "left_eyelid", # 30 vertices
+ torch.tensor([
+ 807, 808, 809, 814, 815, 816, 821, 822, 823, 824, 825, 826, 827, 828, 829, 841, 842, 848, 864, 865, 877, 878, 879, 880, 881, 882, 883, 884, 885, 896, 897, 903, 904, 905, 922, 923, 924, 926, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 958, 959, 991, 992, 993, 994, 995, 999, 1000, 1003, 1006, 1008, 1011, 1023, 1033, 1034, 1045, 1046, 1059, 1060, 1061, 1062, 1093, 1096, 1101, 1108, 1113, 1114, 1115, 1125, 1126, 1132, 1134, 1135, 1142, 1143, 1144, 1146, 1147, 1150, 1151, 1152, 1153, 1154, 1170, 1175, 1182, 1183, 1194, 1195, 1200, 1201, 1202, 1216, 1217, 1218, 1224, 1227, 1230, 1232, 1233, 1243, 1244, 1283, 1289, 1292, 1293, 1294, 1320, 1329, 1331, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1361, 3827, 3832, 3833, 3835, 3853, 3855, 3856, 3861,
+ ])
+ )
+
+ self.v.register_buffer(
+ "right_eyelid", # 30 vertices
+ torch.tensor([
+ 2264, 2265, 2266, 2267, 2268, 2269, 2270, 2271, 2272, 2273, 2274, 2275, 2276, 2277, 2278, 2282, 2283, 2286, 2287, 2288, 2289, 2290, 2291, 2292, 2293, 2294, 2295, 2296, 2297, 2298, 2299, 2303, 2304, 2305, 2312, 2313, 2314, 2315, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2364, 2365, 2367, 2369, 2381, 2382, 2383, 2386, 2387, 2388, 2389, 2390, 2391, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2411, 2412, 2416, 2417, 2418, 2419, 2420, 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2436, 2437, 2440, 2441, 2446, 2447, 2448, 2449, 2450, 2451, 2452, 2453, 2454, 2457, 2460, 2461, 2462, 2465, 2466, 2467, 2470, 2471, 2472, 2473, 2478, 2485, 2486, 2487, 2488, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2503, 2504, 2505, 2506, 2507, 2508, 2509, 2510, 3619, 3631, 3632, 3638, 3687, 3689, 3690, 3700,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lips_tight", # 30 vertices
+ torch.tensor([
+ 1572, 1573, 1578, 1580, 1581, 1582, 1583, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1750, 1751, 1758, 1764, 1765, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1787, 1788, 1789, 1791, 1792, 1793, 1794, 1795, 1802, 1803, 1804, 1826, 1827, 1830, 1831, 1832, 1835, 1836, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1854, 1860, 1861, 1862, 1865, 2708, 2709, 2714, 2716, 2717, 2718, 2719, 2724, 2725, 2726, 2727, 2728, 2729, 2730, 2731, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2784, 2785, 2786, 2787, 2835, 2836, 2837, 2838, 2839, 2840, 2841, 2842, 2843, 2844, 2845, 2846, 2847, 2848, 2849, 2851, 2852, 2853, 2854, 2855, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 2863, 2865, 2866, 2869, 2872, 2873, 2880, 2881, 2882, 2883, 2884, 2885, 2886, 2887, 2888, 2889, 2890, 2891, 2892, 2894, 2895, 2896, 2897, 2898, 2905, 2906, 2907, 2928, 2929, 2930, 2931, 2932, 2933, 2934, 2935, 2936, 2937, 2938, 2939, 2940, 2941, 2942, 2943, 2944, 2945, 2948, 3497, 3500, 3503, 3504, 3506, 3509, 3512, 3513, 3514, 3531, 3533, 3546, 3547, 3549,
+ ])
+ )
+
+ self.v.register_buffer(
+ "left_half",
+ torch.tensor([
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 530, 531, 532, 533, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 588, 589, 590, 591, 592, 593, 594, 603, 604, 605, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 638, 639, 644, 645, 646, 647, 648, 649, 650, 667, 668, 669, 670, 671, 672, 673, 674, 679, 680, 681, 682, 683, 688, 691, 692, 693, 694, 695, 696, 697, 702, 703, 704, 705, 706, 707, 708, 709, 712, 713, 714, 715, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 745, 746, 747, 748, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 783, 784, 785, 786, 795, 796, 797, 798, 799, 802, 803, 804, 805, 806, 807, 808, 809, 814, 815, 816, 821, 822, 823, 824, 825, 826, 827, 828, 829, 837, 838, 840, 841, 842, 846, 847, 848, 864, 865, 877, 878, 879, 880, 881, 882, 883, 884, 885, 896, 897, 898, 899, 902, 903, 904, 905, 906, 907, 908, 909, 918, 919, 922, 923, 924, 926, 927, 928, 929, 939, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 977, 978, 979, 980, 985, 986, 991, 992, 993, 994, 995, 999, 1000, 1001, 1002, 1003, 1006, 1007, 1008, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1033, 1034, 1043, 1044, 1045, 1046, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1068, 1075, 1085, 1086, 1087, 1088, 1092, 1093, 1096, 1101, 1108, 1113, 1114, 1115, 1116, 1117, 1125, 1126, 1127, 1128, 1129, 1132, 1134, 1135, 1142, 1143, 1144, 1146, 1147, 1150, 1151, 1152, 1153, 1154, 1155, 1161, 1162, 1163, 1164, 1168, 1169, 1170, 1175, 1176, 1181, 1182, 1183, 1184, 1189, 1190, 1193, 1194, 1195, 1200, 1201, 1202, 1216, 1217, 1218, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1232, 1233, 1241, 1242, 1243, 1244, 1283, 1284, 1287, 1289, 1292, 1293, 1294, 1298, 1299, 1308, 1309, 1320, 1321, 1322, 1323, 1324, 1325, 1326, 1329, 1331, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1361, 1362, 1363, 1364, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1373, 1374, 1375, 1376, 1377, 1378, 1383, 1384, 1385, 1386, 1387, 1388, 1389, 1390, 1391, 1396, 1397, 1398, 1399, 1400, 1401, 1402, 1403, 1404, 1405, 1410, 1411, 1412, 1413, 1414, 1415, 1416, 1417, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1430, 1431, 1432, 1433, 1434, 1435, 1436, 1437, 1438, 1439, 1440, 1441, 1442, 1443, 1444, 1445, 1446, 1447, 1448, 1449, 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458, 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1489, 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574, 1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1600, 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1617, 1618, 1623, 1624, 1625, 1626, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1672, 1673, 1674, 1675, 1676, 1677, 1678, 1679, 1680, 1681, 1682, 1683, 1684, 1685, 1686, 1687, 1688, 1689, 1690, 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1701, 1702, 1703, 1704, 1705, 1706, 1707, 1708, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1735, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1749, 1750, 1751, 1756, 1757, 1758, 1759, 1763, 1764, 1765, 1766, 1767, 1768, 1769, 1770, 1771, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1787, 1788, 1789, 1790, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1823, 1824, 1825, 1826, 1827, 1830, 1831, 1832, 1835, 1836, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1854, 1860, 1861, 1862, 1863, 1864, 1865, 1866, 1867, 1868, 1869, 1871, 1872, 1873, 1874, 1875, 1876, 1877, 1878, 1879, 1880, 1881, 1886, 1887, 1888, 1889, 1890, 1891, 1892, 1893, 1894, 1895, 1896, 1897, 1898, 1899, 1900, 1901, 1902, 1903, 1904, 1905, 1906, 1907, 1908, 1909, 1910, 1911, 1914, 1915, 1917, 1918, 1919, 1920, 1921, 1922, 1923, 1924, 1925, 1926, 1927, 1928, 1938, 1939, 1942, 1943, 1944, 1945, 1946, 1947, 1948, 1949, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957, 1958, 1959, 1964, 1965, 1966, 1967, 1968, 1969, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2004, 2009, 2010, 2011, 2012, 2021, 2022, 2023, 2024, 2025, 2026, 2029, 2030, 2033, 2034, 2035, 2036, 2037, 2038, 2039, 2040, 2041, 2042, 2043, 2044, 2045, 2046, 2047, 2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055, 2056, 2057, 2058, 2059, 2060, 2061, 2062, 2063, 2064, 2065, 2066, 2067, 2068, 2069, 2070, 2071, 2072, 2073, 2074, 2075, 2076, 2077, 2078, 2079, 2080, 2081, 2082, 2083, 2092, 2093, 2094, 2095, 2096, 2097, 2098, 2099, 2100, 2101, 2102, 2103, 2104, 2105, 2106, 2107, 2108, 2109, 2110, 2111, 2112, 2113, 2114, 2115, 2116, 2117, 2118, 2119, 2120, 2121, 2122, 2125, 2126, 2127, 2134, 2135, 2136, 2137, 2138, 2139, 2140, 2141, 2142, 2143, 2148, 2151, 2152, 2153, 2154, 2155, 2156, 2157, 2158, 2159, 2160, 2161, 2162, 2163, 2164, 2169, 2170, 2171, 2172, 2173, 2174, 2175, 3186, 3187, 3188, 3189, 3190, 3191, 3192, 3193, 3194, 3195, 3196, 3197, 3198, 3199, 3200, 3201, 3202, 3203, 3204, 3205, 3206, 3207, 3208, 3209, 3210, 3211, 3212, 3213, 3214, 3215, 3216, 3217, 3218, 3219, 3220, 3221, 3222, 3223, 3224, 3225, 3226, 3227, 3228, 3229, 3230, 3231, 3232, 3233, 3234, 3235, 3236, 3237, 3238, 3239, 3240, 3241, 3242, 3243, 3244, 3245, 3246, 3247, 3248, 3249, 3250, 3251, 3252, 3253, 3254, 3255, 3256, 3257, 3258, 3259, 3260, 3261, 3262, 3263, 3264, 3265, 3266, 3267, 3268, 3269, 3270, 3271, 3272, 3273, 3274, 3275, 3276, 3277, 3278, 3279, 3280, 3281, 3282, 3283, 3284, 3285, 3286, 3287, 3288, 3289, 3290, 3399, 3400, 3401, 3404, 3414, 3442, 3457, 3459, 3461, 3463, 3487, 3494, 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3502, 3503, 3504, 3505, 3506, 3507, 3508, 3509, 3510, 3511, 3512, 3513, 3514, 3515, 3516, 3517, 3518, 3519, 3520, 3521, 3522, 3523, 3524, 3525, 3526, 3527, 3528, 3529, 3530, 3531, 3532, 3533, 3534, 3535, 3536, 3537, 3538, 3539, 3540, 3541, 3542, 3543, 3544, 3545, 3546, 3547, 3548, 3549, 3550, 3551, 3552, 3553, 3554, 3555, 3556, 3557, 3558, 3559, 3560, 3561, 3562, 3563, 3564, 3565, 3566, 3567, 3568, 3569, 3570, 3571, 3572, 3573, 3574, 3575, 3576, 3577, 3578, 3579, 3580, 3581, 3582, 3583, 3584, 3587, 3588, 3593, 3594, 3595, 3596, 3598, 3599, 3600, 3601, 3604, 3605, 3611, 3614, 3623, 3624, 3625, 3626, 3628, 3629, 3630, 3634, 3635, 3636, 3637, 3643, 3644, 3646, 3649, 3650, 3652, 3653, 3654, 3655, 3656, 3658, 3659, 3660, 3662, 3663, 3664, 3665, 3666, 3667, 3668, 3670, 3671, 3672, 3673, 3676, 3677, 3678, 3679, 3680, 3681, 3685, 3691, 3693, 3695, 3697, 3698, 3701, 3703, 3704, 3707, 3709, 3713, 3714, 3715, 3716, 3717, 3722, 3724, 3725, 3726, 3727, 3728, 3730, 3734, 3737, 3738, 3739, 3740, 3742, 3745, 3752, 3753, 3754, 3756, 3757, 3760, 3761, 3762, 3769, 3771, 3772, 3785, 3786, 3790, 3801, 3807, 3808, 3809, 3810, 3811, 3812, 3813, 3814, 3815, 3816, 3817, 3818, 3819, 3820, 3821, 3822, 3823, 3824, 3825, 3826, 3827, 3828, 3829, 3830, 3831, 3832, 3833, 3834, 3835, 3836, 3837, 3838, 3839, 3840, 3841, 3842, 3843, 3844, 3845, 3846, 3847, 3848, 3849, 3850, 3851, 3852, 3853, 3854, 3855, 3856, 3857, 3858, 3859, 3860, 3861, 3862, 3863, 3864, 3865, 3866, 3867, 3868, 3869, 3870, 3871, 3872, 3873, 3874, 3875, 3876, 3877, 3878, 3879, 3880, 3881, 3882, 3883, 3884, 3885, 3886, 3887, 3888, 3889, 3890, 3891, 3892, 3893, 3894, 3895, 3896, 3897, 3898, 3899, 3900, 3901, 3902, 3903, 3904, 3905, 3906, 3907, 3908, 3909, 3910, 3911, 3912, 3913, 3914, 3915, 3916, 3917, 3918, 3919, 3920, 3921, 3922, 3923, 3924, 3925, 3926, 3927, 3928, 3929, 3931, 3932, 3933, 3934, 3935, 3936, 3937, 3938, 3939, 3940, 3941, 3942, 3943, 3944, 3945, 3946, 3947, 3948, 3949, 3950, 3951, 3952, 3953, 3954, 3955, 3956, 3957, 3958, 3959, 3960, 3961, 3962, 3963, 3964, 3965, 3966, 3967, 3968, 3969, 3970, 3971, 3972, 3973, 3974, 3975, 3976, 3977, 3978, 3979, 3980, 3981, 3982, 3983, 3984, 3985, 3986, 3987, 3988, 3989, 3990, 3991, 3992, 3993, 3994, 3995, 3996, 3997, 3998, 3999, 4000, 4001, 4002, 4003, 4004, 4005, 4006, 4007, 4008, 4009, 4010, 4011, 4012, 4013, 4014, 4015, 4016, 4017, 4018, 4019, 4020, 4021, 4022, 4023, 4024, 4025, 4026, 4027, 4028, 4029, 4030, 4031, 4032, 4033, 4034, 4035, 4036, 4037, 4038, 4039, 4040, 4041, 4042, 4043, 4044, 4045, 4046, 4047, 4048, 4049, 4050, 4051, 4052, 4053, 4054, 4055, 4056, 4057, 4058, 4059, 4060, 4061, 4062, 4063, 4064, 4065, 4066, 4067, 4068, 4069, 4070, 4071, 4072, 4073, 4074, 4075, 4076, 4077, 4078, 4079, 4080, 4081, 4082, 4083, 4084, 4085, 4086, 4087, 4088, 4089, 4090, 4091, 4092, 4093, 4094, 4095, 4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103, 4104, 4105, 4106, 4107, 4108, 4109, 4110, 4111, 4112, 4113, 4114, 4115, 4116, 4117, 4118, 4119, 4120, 4121, 4122, 4123, 4124, 4125, 4126, 4127, 4128, 4129, 4130, 4131, 4132, 4133, 4134, 4135, 4136, 4137, 4138, 4139, 4140, 4141, 4142, 4143, 4144, 4145, 4146, 4147, 4148, 4149, 4150, 4151, 4152, 4153, 4154, 4155, 4156, 4157, 4158, 4159, 4160, 4161, 4162, 4163, 4164, 4165, 4166, 4167, 4168, 4169, 4170, 4171, 4172, 4173, 4174, 4175, 4176, 4177, 4178, 4179, 4180, 4181, 4182, 4183, 4184, 4185, 4186, 4187, 4188, 4189, 4190, 4191, 4192, 4193, 4194, 4195, 4196, 4197, 4198, 4199, 4200, 4201, 4202, 4203, 4204, 4205, 4206, 4207, 4208, 4209, 4210, 4211, 4212, 4213, 4214, 4215, 4216, 4217, 4218, 4219, 4220, 4221, 4222, 4223, 4224, 4225, 4226, 4227, 4228, 4229, 4230, 4231, 4232, 4233, 4234, 4235, 4236, 4237, 4238, 4239, 4240, 4241, 4242, 4243, 4244, 4245, 4246, 4247, 4248, 4249, 4250, 4251, 4252, 4253, 4254, 4255, 4256, 4257, 4258, 4259, 4260, 4261, 4262, 4263, 4264, 4265, 4266, 4267, 4268, 4269, 4270, 4271, 4272, 4273, 4274, 4275, 4276, 4277, 4278, 4279, 4280, 4281, 4282, 4283, 4284, 4285, 4286, 4287, 4288, 4289, 4290, 4291, 4292, 4293, 4294, 4295, 4296, 4297, 4298, 4299, 4300, 4301, 4302, 4303, 4304, 4305, 4306, 4307, 4308, 4309, 4310, 4311, 4312, 4313, 4314, 4315, 4316, 4317, 4318, 4319, 4320, 4321, 4322, 4323, 4324, 4325, 4326, 4327, 4328, 4329, 4330, 4331, 4332, 4333, 4334, 4335, 4336, 4337, 4338, 4339, 4340, 4341, 4342, 4343, 4344, 4345, 4346, 4347, 4348, 4349, 4350, 4351, 4352, 4353, 4354, 4355, 4356, 4357, 4358, 4359, 4360, 4361, 4362, 4363, 4364, 4365, 4366, 4367, 4368, 4369, 4370, 4371, 4372, 4373, 4374, 4375, 4376, 4377, 4378, 4379, 4380, 4381, 4382, 4383, 4384, 4385, 4386, 4387, 4388, 4389, 4390, 4391, 4392, 4393, 4394, 4395, 4396, 4397, 4398, 4399, 4400, 4401, 4402, 4403, 4404, 4405, 4406, 4407, 4408, 4409, 4410, 4411, 4412, 4413, 4414, 4415, 4416, 4417, 4418, 4419, 4420, 4421, 4422, 4423, 4424, 4425, 4426, 4427, 4428, 4429, 4430, 4431, 4432, 4433, 4434, 4435, 4436, 4437, 4438, 4439, 4440, 4441, 4442, 4443, 4444, 4445, 4446, 4447, 4448, 4449, 4450, 4451, 4452, 4453, 4454, 4455, 4456, 4457, 4458, 4459, 4460, 4461, 4462, 4463, 4464, 4465, 4466, 4467, 4468, 4469, 4470, 4471, 4472, 4473, 4474, 4475, 4476,
+ ])
+ )
+
+ self.v.register_buffer(
+ "right_half",
+ torch.tensor([
+ 19, 20, 21, 22, 23, 24, 25, 26, 109, 110, 111, 112, 219, 220, 221, 222, 335, 336, 337, 338, 522, 523, 524, 525, 526, 527, 528, 529, 534, 535, 536, 537, 554, 555, 556, 557, 584, 585, 586, 587, 595, 596, 597, 598, 599, 600, 601, 602, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 634, 635, 636, 637, 640, 641, 642, 643, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 675, 676, 677, 678, 684, 685, 686, 687, 689, 690, 698, 699, 700, 701, 710, 711, 716, 717, 718, 719, 720, 721, 722, 741, 742, 743, 744, 749, 750, 751, 752, 776, 777, 778, 779, 780, 781, 782, 787, 788, 789, 790, 791, 792, 793, 794, 800, 801, 810, 811, 812, 813, 817, 818, 819, 820, 830, 831, 832, 833, 834, 835, 836, 839, 843, 844, 845, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 900, 901, 910, 911, 912, 913, 914, 915, 916, 917, 920, 921, 925, 930, 931, 932, 933, 934, 935, 936, 937, 938, 940, 941, 956, 957, 973, 974, 975, 976, 981, 982, 983, 984, 987, 988, 989, 990, 996, 997, 998, 1004, 1005, 1009, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1066, 1067, 1069, 1070, 1071, 1072, 1073, 1074, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1089, 1090, 1091, 1094, 1095, 1097, 1098, 1099, 1100, 1102, 1103, 1104, 1105, 1106, 1107, 1109, 1110, 1111, 1112, 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1130, 1131, 1133, 1136, 1137, 1138, 1139, 1140, 1141, 1145, 1148, 1149, 1156, 1157, 1158, 1159, 1160, 1165, 1166, 1167, 1171, 1172, 1173, 1174, 1177, 1178, 1179, 1180, 1185, 1186, 1187, 1188, 1191, 1192, 1196, 1197, 1198, 1199, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1219, 1220, 1221, 1222, 1223, 1231, 1234, 1235, 1236, 1237, 1238, 1239, 1240, 1245, 1246, 1247, 1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1285, 1286, 1288, 1290, 1291, 1295, 1296, 1297, 1300, 1301, 1302, 1303, 1304, 1305, 1306, 1307, 1310, 1311, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1327, 1328, 1330, 1332, 1333, 1334, 1335, 1359, 1360, 1379, 1380, 1381, 1382, 1392, 1393, 1394, 1395, 1406, 1407, 1408, 1409, 1488, 1613, 1614, 1615, 1616, 1619, 1620, 1621, 1622, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1636, 1637, 1726, 1727, 1752, 1753, 1754, 1755, 1760, 1761, 1762, 1772, 1783, 1784, 1785, 1786, 1822, 1828, 1829, 1833, 1834, 1837, 1838, 1839, 1840, 1841, 1842, 1843, 1844, 1845, 1853, 1855, 1856, 1857, 1858, 1859, 1870, 1882, 1883, 1884, 1885, 1912, 1913, 1916, 1929, 1930, 1931, 1932, 1933, 1934, 1935, 1936, 1937, 1940, 1941, 1960, 1961, 1962, 1963, 1982, 1983, 1984, 1985, 2000, 2001, 2002, 2003, 2005, 2006, 2007, 2008, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2027, 2028, 2031, 2032, 2036, 2084, 2085, 2086, 2087, 2088, 2089, 2090, 2091, 2123, 2124, 2128, 2129, 2130, 2131, 2132, 2133, 2144, 2145, 2146, 2147, 2149, 2150, 2151, 2165, 2166, 2167, 2168, 2176, 2177, 2178, 2179, 2180, 2181, 2182, 2183, 2184, 2185, 2186, 2187, 2188, 2189, 2190, 2191, 2192, 2193, 2194, 2195, 2196, 2197, 2198, 2199, 2200, 2201, 2202, 2203, 2204, 2205, 2206, 2207, 2208, 2209, 2210, 2211, 2212, 2213, 2214, 2215, 2216, 2217, 2218, 2219, 2220, 2221, 2222, 2223, 2224, 2225, 2226, 2227, 2228, 2229, 2230, 2231, 2232, 2233, 2234, 2235, 2236, 2237, 2238, 2239, 2240, 2241, 2242, 2243, 2244, 2245, 2246, 2247, 2248, 2249, 2250, 2251, 2252, 2253, 2254, 2255, 2256, 2257, 2258, 2259, 2260, 2261, 2262, 2263, 2264, 2265, 2266, 2267, 2268, 2269, 2270, 2271, 2272, 2273, 2274, 2275, 2276, 2277, 2278, 2279, 2280, 2281, 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2290, 2291, 2292, 2293, 2294, 2295, 2296, 2297, 2298, 2299, 2300, 2301, 2302, 2303, 2304, 2305, 2306, 2307, 2308, 2309, 2310, 2311, 2312, 2313, 2314, 2315, 2316, 2317, 2318, 2319, 2320, 2321, 2322, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335, 2336, 2337, 2338, 2339, 2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347, 2348, 2349, 2350, 2351, 2352, 2353, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2362, 2363, 2364, 2365, 2366, 2367, 2368, 2369, 2370, 2371, 2372, 2373, 2374, 2375, 2376, 2377, 2378, 2379, 2380, 2381, 2382, 2383, 2384, 2385, 2386, 2387, 2388, 2389, 2390, 2391, 2392, 2393, 2394, 2395, 2396, 2397, 2398, 2399, 2400, 2401, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2409, 2410, 2411, 2412, 2413, 2414, 2415, 2416, 2417, 2418, 2419, 2420, 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2429, 2430, 2431, 2432, 2433, 2434, 2435, 2436, 2437, 2438, 2439, 2440, 2441, 2442, 2443, 2444, 2445, 2446, 2447, 2448, 2449, 2450, 2451, 2452, 2453, 2454, 2455, 2456, 2457, 2458, 2459, 2460, 2461, 2462, 2463, 2464, 2465, 2466, 2467, 2468, 2469, 2470, 2471, 2472, 2473, 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2482, 2483, 2484, 2485, 2486, 2487, 2488, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2497, 2498, 2499, 2500, 2501, 2502, 2503, 2504, 2505, 2506, 2507, 2508, 2509, 2510, 2511, 2512, 2513, 2514, 2515, 2516, 2517, 2518, 2519, 2520, 2521, 2522, 2523, 2524, 2525, 2526, 2527, 2528, 2529, 2530, 2531, 2532, 2533, 2534, 2535, 2536, 2537, 2538, 2539, 2540, 2541, 2542, 2543, 2544, 2545, 2546, 2547, 2548, 2549, 2550, 2551, 2552, 2553, 2554, 2555, 2556, 2557, 2558, 2559, 2560, 2561, 2562, 2563, 2564, 2565, 2566, 2567, 2568, 2569, 2570, 2571, 2572, 2573, 2574, 2575, 2576, 2577, 2578, 2579, 2580, 2581, 2582, 2583, 2584, 2585, 2586, 2587, 2588, 2589, 2590, 2591, 2592, 2593, 2594, 2595, 2596, 2597, 2598, 2599, 2600, 2601, 2602, 2603, 2604, 2605, 2606, 2607, 2608, 2609, 2610, 2611, 2612, 2613, 2614, 2615, 2616, 2617, 2618, 2619, 2620, 2621, 2622, 2623, 2624, 2625, 2626, 2627, 2628, 2629, 2630, 2631, 2632, 2633, 2634, 2635, 2636, 2637, 2638, 2639, 2640, 2641, 2642, 2643, 2644, 2645, 2646, 2647, 2648, 2649, 2650, 2651, 2652, 2653, 2654, 2655, 2656, 2657, 2658, 2659, 2660, 2661, 2662, 2663, 2664, 2665, 2666, 2667, 2668, 2669, 2670, 2671, 2672, 2673, 2674, 2675, 2676, 2677, 2678, 2679, 2680, 2681, 2682, 2683, 2684, 2685, 2686, 2687, 2688, 2689, 2690, 2691, 2692, 2693, 2694, 2695, 2696, 2697, 2698, 2699, 2700, 2701, 2702, 2703, 2704, 2705, 2706, 2707, 2708, 2709, 2710, 2711, 2712, 2713, 2714, 2715, 2716, 2717, 2718, 2719, 2720, 2721, 2722, 2723, 2724, 2725, 2726, 2727, 2728, 2729, 2730, 2731, 2732, 2733, 2734, 2735, 2736, 2737, 2738, 2739, 2740, 2741, 2742, 2743, 2744, 2745, 2746, 2747, 2748, 2749, 2750, 2751, 2752, 2753, 2754, 2755, 2756, 2757, 2758, 2759, 2760, 2761, 2762, 2763, 2764, 2765, 2766, 2767, 2768, 2769, 2770, 2771, 2772, 2773, 2774, 2775, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2784, 2785, 2786, 2787, 2788, 2789, 2790, 2791, 2792, 2793, 2794, 2795, 2796, 2797, 2798, 2799, 2800, 2801, 2802, 2803, 2804, 2805, 2806, 2807, 2808, 2809, 2810, 2811, 2812, 2813, 2814, 2815, 2816, 2817, 2818, 2819, 2820, 2821, 2822, 2823, 2824, 2825, 2826, 2827, 2828, 2829, 2830, 2831, 2832, 2833, 2834, 2835, 2836, 2837, 2838, 2839, 2840, 2841, 2842, 2843, 2844, 2845, 2846, 2847, 2848, 2849, 2850, 2851, 2852, 2853, 2854, 2855, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 2863, 2864, 2865, 2866, 2867, 2868, 2869, 2870, 2871, 2872, 2873, 2874, 2875, 2876, 2877, 2878, 2879, 2880, 2881, 2882, 2883, 2884, 2885, 2886, 2887, 2888, 2889, 2890, 2891, 2892, 2893, 2894, 2895, 2896, 2897, 2898, 2899, 2900, 2901, 2902, 2903, 2904, 2905, 2906, 2907, 2908, 2909, 2910, 2911, 2912, 2913, 2914, 2915, 2916, 2917, 2918, 2919, 2920, 2921, 2922, 2923, 2924, 2925, 2926, 2927, 2928, 2929, 2930, 2931, 2932, 2933, 2934, 2935, 2936, 2937, 2938, 2939, 2940, 2941, 2942, 2943, 2944, 2945, 2946, 2947, 2948, 2949, 2950, 2951, 2952, 2953, 2954, 2955, 2956, 2957, 2958, 2959, 2960, 2961, 2962, 2963, 2964, 2965, 2966, 2967, 2968, 2969, 2970, 2971, 2972, 2973, 2974, 2975, 2976, 2977, 2978, 2979, 2980, 2981, 2982, 2983, 2984, 2985, 2986, 2987, 2988, 2989, 2990, 2991, 2992, 2993, 2994, 2995, 2996, 2997, 2998, 2999, 3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009, 3010, 3011, 3012, 3013, 3014, 3015, 3016, 3017, 3018, 3019, 3020, 3021, 3022, 3023, 3024, 3025, 3026, 3027, 3028, 3029, 3030, 3031, 3032, 3033, 3034, 3035, 3036, 3037, 3038, 3039, 3040, 3041, 3042, 3043, 3044, 3045, 3046, 3047, 3048, 3049, 3050, 3051, 3052, 3053, 3054, 3055, 3056, 3057, 3058, 3059, 3060, 3061, 3062, 3063, 3064, 3065, 3066, 3067, 3068, 3069, 3070, 3071, 3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079, 3080, 3081, 3082, 3083, 3084, 3085, 3086, 3087, 3088, 3089, 3090, 3091, 3092, 3093, 3094, 3095, 3096, 3097, 3098, 3099, 3100, 3101, 3102, 3103, 3104, 3105, 3106, 3107, 3108, 3109, 3110, 3111, 3112, 3113, 3114, 3115, 3116, 3117, 3118, 3119, 3120, 3121, 3122, 3123, 3124, 3125, 3126, 3127, 3128, 3129, 3130, 3131, 3132, 3133, 3134, 3135, 3136, 3137, 3138, 3139, 3140, 3141, 3142, 3143, 3144, 3145, 3146, 3147, 3148, 3149, 3150, 3151, 3152, 3153, 3154, 3155, 3156, 3157, 3158, 3159, 3160, 3161, 3162, 3163, 3164, 3165, 3166, 3167, 3168, 3169, 3170, 3171, 3172, 3173, 3174, 3175, 3176, 3177, 3178, 3179, 3180, 3181, 3182, 3183, 3184, 3185, 3222, 3223, 3248, 3249, 3275, 3276, 3277, 3278, 3281, 3282, 3283, 3284, 3285, 3290, 3291, 3292, 3293, 3294, 3295, 3296, 3297, 3298, 3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308, 3309, 3310, 3311, 3312, 3313, 3314, 3315, 3316, 3317, 3318, 3319, 3320, 3321, 3322, 3323, 3324, 3325, 3326, 3327, 3328, 3329, 3330, 3331, 3332, 3333, 3334, 3335, 3336, 3337, 3338, 3339, 3340, 3341, 3342, 3343, 3344, 3345, 3346, 3347, 3348, 3349, 3350, 3351, 3352, 3353, 3354, 3355, 3356, 3357, 3358, 3359, 3360, 3361, 3362, 3363, 3364, 3365, 3366, 3367, 3368, 3369, 3370, 3371, 3372, 3373, 3374, 3375, 3376, 3377, 3378, 3379, 3380, 3381, 3382, 3383, 3384, 3385, 3386, 3387, 3388, 3389, 3390, 3391, 3392, 3393, 3394, 3395, 3396, 3397, 3398, 3399, 3400, 3401, 3402, 3403, 3404, 3405, 3406, 3407, 3408, 3409, 3410, 3411, 3412, 3413, 3414, 3415, 3416, 3417, 3418, 3419, 3420, 3421, 3422, 3423, 3424, 3425, 3426, 3427, 3428, 3429, 3430, 3431, 3432, 3433, 3434, 3435, 3436, 3437, 3438, 3439, 3440, 3441, 3442, 3443, 3444, 3445, 3446, 3447, 3448, 3449, 3450, 3451, 3452, 3453, 3454, 3455, 3456, 3457, 3458, 3459, 3460, 3461, 3462, 3463, 3464, 3465, 3466, 3467, 3468, 3469, 3470, 3471, 3472, 3473, 3474, 3475, 3476, 3477, 3478, 3479, 3480, 3481, 3482, 3483, 3484, 3485, 3486, 3487, 3488, 3489, 3490, 3491, 3492, 3493, 3494, 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3502, 3503, 3504, 3505, 3506, 3507, 3508, 3509, 3510, 3511, 3512, 3513, 3514, 3515, 3516, 3517, 3518, 3519, 3520, 3521, 3522, 3523, 3524, 3525, 3526, 3527, 3528, 3529, 3530, 3531, 3532, 3533, 3534, 3535, 3536, 3537, 3538, 3539, 3540, 3541, 3542, 3543, 3544, 3545, 3546, 3547, 3548, 3549, 3550, 3551, 3552, 3553, 3554, 3555, 3556, 3557, 3558, 3559, 3560, 3561, 3562, 3563, 3564, 3565, 3566, 3567, 3568, 3569, 3570, 3571, 3572, 3573, 3574, 3575, 3585, 3586, 3589, 3590, 3591, 3592, 3597, 3602, 3603, 3606, 3607, 3608, 3609, 3610, 3612, 3613, 3615, 3616, 3617, 3618, 3619, 3620, 3621, 3622, 3627, 3631, 3632, 3633, 3638, 3639, 3640, 3641, 3642, 3645, 3647, 3648, 3651, 3657, 3661, 3668, 3669, 3674, 3675, 3682, 3683, 3684, 3686, 3687, 3688, 3689, 3690, 3692, 3694, 3696, 3699, 3700, 3702, 3704, 3705, 3706, 3708, 3710, 3711, 3712, 3718, 3719, 3720, 3721, 3723, 3729, 3731, 3732, 3733, 3735, 3736, 3741, 3743, 3744, 3746, 3747, 3748, 3749, 3750, 3751, 3755, 3758, 3759, 3763, 3764, 3765, 3766, 3767, 3768, 3770, 3773, 3774, 3775, 3776, 3777, 3778, 3779, 3780, 3781, 3782, 3783, 3784, 3785, 3786, 3787, 3788, 3789, 3790, 3791, 3792, 3793, 3794, 3795, 3796, 3797, 3798, 3799, 3800, 3801, 3802, 3803, 3804, 3805, 3806, 3930, 4477, 4478, 4479, 4480, 4481, 4482, 4483, 4484, 4485, 4486, 4487, 4488, 4489, 4490, 4491, 4492, 4493, 4494, 4495, 4496, 4497, 4498, 4499, 4500, 4501, 4502, 4503, 4504, 4505, 4506, 4507, 4508, 4509, 4510, 4511, 4512, 4513, 4514, 4515, 4516, 4517, 4518, 4519, 4520, 4521, 4522, 4523, 4524, 4525, 4526, 4527, 4528, 4529, 4530, 4531, 4532, 4533, 4534, 4535, 4536, 4537, 4538, 4539, 4540, 4541, 4542, 4543, 4544, 4545, 4546, 4547, 4548, 4549, 4550, 4551, 4552, 4553, 4554, 4555, 4556, 4557, 4558, 4559, 4560, 4561, 4562, 4563, 4564, 4565, 4566, 4567, 4568, 4569, 4570, 4571, 4572, 4573, 4574, 4575, 4576, 4577, 4578, 4579, 4580, 4581, 4582, 4583, 4584, 4585, 4586, 4587, 4588, 4589, 4590, 4591, 4592, 4593, 4594, 4595, 4596, 4597, 4598, 4599, 4600, 4601, 4602, 4603, 4604, 4605, 4606, 4607, 4608, 4609, 4610, 4611, 4612, 4613, 4614, 4615, 4616, 4617, 4618, 4619, 4620, 4621, 4622, 4623, 4624, 4625, 4626, 4627, 4628, 4629, 4630, 4631, 4632, 4633, 4634, 4635, 4636, 4637, 4638, 4639, 4640, 4641, 4642, 4643, 4644, 4645, 4646, 4647, 4648, 4649, 4650, 4651, 4652, 4653, 4654, 4655, 4656, 4657, 4658, 4659, 4660, 4661, 4662, 4663, 4664, 4665, 4666, 4667, 4668, 4669, 4670, 4671, 4672, 4673, 4674, 4675, 4676, 4677, 4678, 4679, 4680, 4681, 4682, 4683, 4684, 4685, 4686, 4687, 4688, 4689, 4690, 4691, 4692, 4693, 4694, 4695, 4696, 4697, 4698, 4699, 4700, 4701, 4702, 4703, 4704, 4705, 4706, 4707, 4708, 4709, 4710, 4711, 4712, 4713, 4714, 4715, 4716, 4717, 4718, 4719, 4720, 4721, 4722, 4723, 4724, 4725, 4726, 4727, 4728, 4729, 4730, 4731, 4732, 4733, 4734, 4735, 4736, 4737, 4738, 4739, 4740, 4741, 4742, 4743, 4744, 4745, 4746, 4747, 4748, 4749, 4750, 4751, 4752, 4753, 4754, 4755, 4756, 4757, 4758, 4759, 4760, 4761, 4762, 4763, 4764, 4765, 4766, 4767, 4768, 4769, 4770, 4771, 4772, 4773, 4774, 4775, 4776, 4777, 4778, 4779, 4780, 4781, 4782, 4783, 4784, 4785, 4786, 4787, 4788, 4789, 4790, 4791, 4792, 4793, 4794, 4795, 4796, 4797, 4798, 4799, 4800, 4801, 4802, 4803, 4804, 4805, 4806, 4807, 4808, 4809, 4810, 4811, 4812, 4813, 4814, 4815, 4816, 4817, 4818, 4819, 4820, 4821, 4822, 4823, 4824, 4825, 4826, 4827, 4828, 4829, 4830, 4831, 4832, 4833, 4834, 4835, 4836, 4837, 4838, 4839, 4840, 4841, 4842, 4843, 4844, 4845, 4846, 4847, 4848, 4849, 4850, 4851, 4852, 4853, 4854, 4855, 4856, 4857, 4858, 4859, 4860, 4861, 4862, 4863, 4864, 4865, 4866, 4867, 4868, 4869, 4870, 4871, 4872, 4873, 4874, 4875, 4876, 4877, 4878, 4879, 4880, 4881, 4882, 4883, 4884, 4885, 4886, 4887, 4888, 4889, 4890, 4891, 4892, 4893, 4894, 4895, 4896, 4897, 4898, 4899, 4900, 4901, 4902, 4903, 4904, 4905, 4906, 4907, 4908, 4909, 4910, 4911, 4912, 4913, 4914, 4915, 4916, 4917, 4918, 4919, 4920, 4921, 4922, 4923, 4924, 4925, 4926, 4927, 4928, 4929, 4930, 4931, 4932, 4933, 4934, 4935, 4936, 4937, 4938, 4939, 4940, 4941, 4942, 4943, 4944, 4945, 4946, 4947, 4948, 4949, 4950, 4951, 4952, 4953, 4954, 4955, 4956, 4957, 4958, 4959, 4960, 4961, 4962, 4963, 4964, 4965, 4966, 4967, 4968, 4969, 4970, 4971, 4972, 4973, 4974, 4975, 4976, 4977, 4978, 4979, 4980, 4981, 4982, 4983, 4984, 4985, 4986, 4987, 4988, 4989, 4990, 4991, 4992, 4993, 4994, 4995, 4996, 4997, 4998, 4999, 5000, 5001, 5002, 5003, 5004, 5005, 5006, 5007, 5008, 5009, 5010, 5011, 5012, 5013, 5014, 5015, 5016, 5017, 5018, 5019, 5020, 5021, 5022
+ ])
+ )
+
+ # remove the intersection with neck from scalp and get the region for hair
+ face_and_neck = torch.cat([self.v.face, self.v.neck]).unique()
+ # get the intersection between scalp and face_and_neck
+ uniques, counts = torch.cat([self.v.scalp, face_and_neck]).unique(return_counts=True)
+ intersection = uniques[counts == 2]
+ uniques, counts = torch.cat([self.v.scalp, intersection]).unique(return_counts=True)
+ hair = uniques[counts == 1]
+ self.v.register_buffer("hair", hair)
+
+ # unions
+ self.v.register_buffer("ears", torch.cat([self.v.right_ear, self.v.left_ear]))
+ self.v.register_buffer("eyeballs", torch.cat([self.v.right_eyeball, self.v.left_eyeball]))
+ self.v.register_buffer("irises", torch.cat([self.v.right_iris, self.v.left_iris]))
+ self.v.register_buffer("left_eye", torch.cat([self.v.left_eye_region, self.v.left_eyeball]))
+ self.v.register_buffer("right_eye", torch.cat([self.v.right_eye_region, self.v.right_eyeball]))
+ self.v.register_buffer("eyelids", torch.cat([self.v.left_eyelid, self.v.right_eyelid]))
+ self.v.register_buffer("lip_inside_ring", torch.cat([self.v.lip_inside_ring_upper, self.v.lip_inside_ring_lower, torch.tensor([1594, 2730])]))
+
+ # remove the intersection with irises from eyeballs and get the region for scleras
+ uniques, counts = torch.cat([self.v.eyeballs, self.v.irises]).unique(return_counts=True)
+ intersection = uniques[counts == 2]
+ uniques, counts = torch.cat([self.v.eyeballs, intersection]).unique(return_counts=True)
+ sclerae = uniques[counts == 1]
+ self.v.register_buffer("sclerae", sclerae)
+
+ # skin
+ skin_except = ["eyeballs", "hair", "lips_tight", "boundary"]
+ if self.num_verts == 5083:
+ skin_except.append("teeth")
+ skin = self.get_vid_except_region(skin_except)
+ self.v.register_buffer("skin", skin)
+
+ def construct_vid_table(self):
+ self.vid_to_region = defaultdict(list) # vertex id -> region name
+ for region_name, v_mask in self.v:
+ for v_id in v_mask:
+ self.vid_to_region[v_id.item()].append(region_name)
+
+ def process_face_mask(self, faces):
+
+ face_masks = defaultdict(list) # region name -> face id
+ for f_id, f in enumerate(faces):
+ counters = defaultdict(int)
+ for v_id in f:
+ for region_name in self.vid_to_region[v_id.item()]:
+ counters[region_name] += 1
+
+ for region_name, count in counters.items():
+ if count >= 3: # create straight boundaries, with seams
+ # if count > 1: # create zigzag boundaries, no seams
+ face_masks[region_name].append(f_id)
+
+ self.f = BufferContainer()
+ for region_name, f_mask in face_masks.items():
+ self.f.register_buffer(region_name, torch.tensor(f_mask, dtype=torch.long))
+
+ def process_face_clusters(self, face_clusters):
+ """ Construct a lookup table from face id to cluster id.
+
+ cluster #0: background
+ cluster #1: foreground
+ cluster #2: faces in face_clusters[0]
+ cluster #3: faces in face_clusters[1]
+ ...
+ """
+ fid2cid = torch.ones(self.num_faces+1, dtype=torch.long) # faces are always treated as foreground
+ for cid, cluster in enumerate(face_clusters):
+ try:
+ fids = self.get_fid_by_region([cluster])
+ except Exception as e:
+ continue
+ fid2cid[fids] = cid + 2 # reserve cluster #0 for the background and #1 for faces that do not belong to any cluster
+ self.register_buffer("fid2cid", fid2cid)
+
+ def process_vt_mask(self, faces, faces_t):
+ vt_masks = defaultdict(list) # region name -> vt id
+ for f_id, (face, face_t) in enumerate(zip(faces, faces_t)):
+ for v_id, vt_id in zip(face, face_t):
+ for region_name in self.vid_to_region[v_id.item()]:
+ vt_masks[region_name].append(vt_id.item())
+
+ self.vt = BufferContainer()
+ for region_name, vt_mask in vt_masks.items():
+ self.vt.register_buffer(region_name, torch.tensor(vt_mask, dtype=torch.long))
+
+ def get_vid_by_region(self, regions, keep_order=False):
+ """Get vertex indicies by regions"""
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ vid = torch.cat([self.v.get_buffer(k) for k in regions])
+ if keep_order:
+ return vid
+ else:
+ return vid.unique()
+ else:
+ return torch.tensor([], dtype=torch.long)
+
+ def get_vid_except_region(self, regions):
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ indices = torch.cat([self.v.get_buffer(k) for k in regions]).unique()
+ else:
+ indices = torch.tensor([], dtype=torch.long)
+
+ # get the vertex indicies that are not included by regions
+ vert_idx = torch.arange(0, self.num_verts, device=indices.device)
+ combined = torch.cat((indices, vert_idx))
+ uniques, counts = combined.unique(return_counts=True)
+ return uniques[counts == 1]
+
+ def get_fid_by_region(self, regions):
+ """Get face indicies by regions"""
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ return torch.cat([self.f.get_buffer(k) for k in regions]).unique()
+ else:
+ return torch.tensor([], dtype=torch.long)
+
+ def get_fid_except_region(self, regions):
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ indices = torch.cat([self.f.get_buffer(k) for k in regions]).unique()
+ else:
+ indices = torch.tensor([], dtype=torch.long)
+
+ # get the face indicies that are not included by regions
+ face_idx = torch.arange(0, self.num_faces, device=indices.device)
+ combined = torch.cat((indices, face_idx))
+ uniques, counts = combined.unique(return_counts=True)
+ return uniques[counts == 1]
+
+ def get_fid_except_fids(self, fids):
+ # get the face indicies that are not included
+ face_idx = torch.arange(0, self.num_faces, device=fids.device)
+ combined = torch.cat((fids, face_idx))
+ uniques, counts = combined.unique(return_counts=True)
+ return uniques[counts == 1]
+
+
+
+if __name__ == '__main__':
+ add_teeth = True
+ subdivide_num = 0
+ teeth_bs_flag = False
+ oral_mesh_flag = False
+ human_model_path = "./pretrained_models/human_model_files"
+ flame_model = FlameHeadSubdivided(
+ 300,
+ 100,
+ add_teeth=add_teeth,
+ add_shoulder=False,
+ flame_model_path=f'{human_model_path}/flame_assets/flame/flame2023.pkl',
+ flame_lmk_embedding_path=f"{human_model_path}/flame_assets/flame/landmark_embedding_with_eyes.npy",
+ flame_template_mesh_path=f"{human_model_path}/flame_assets/flame/head_template_mesh.obj",
+ flame_parts_path=f"{human_model_path}/flame_assets/flame/FLAME_masks.pkl",
+ subdivide_num=subdivide_num,
+ teeth_bs_flag=teeth_bs_flag,
+ oral_mesh_flag=oral_mesh_flag
+ )
diff --git a/lam/models/rendering/flame_model/flame_arkit.py b/lam/models/rendering/flame_model/flame_arkit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9da2483b2053279f5970b7b0b3814c941ccd0b4e
--- /dev/null
+++ b/lam/models/rendering/flame_model/flame_arkit.py
@@ -0,0 +1,1815 @@
+# Code heavily inspired by https://github.com/HavenFeng/photometric_optimization/blob/master/models/FLAME.py.
+# Please consider citing their work if you find this code useful. The code is subject to the license available via
+# https://github.com/vchoutas/flame/edit/master/LICENSE
+import os.path
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+
+from .lbs import lbs, vertices2landmarks, blend_shapes, vertices2joints
+from .lbs import batch_rigid_transform, batch_rodrigues
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pickle
+from collections import defaultdict
+
+try:
+ from pytorch3d.io import load_obj
+except ImportError:
+ from utils.pytorch3d_load_obj import load_obj
+
+from pytorch3d.structures import Meshes
+from pytorch3d.ops import SubdivideMeshes
+
+
+# FLAME_MESH_PATH = "flame_model/assets/flame/head_template_mesh.obj"
+# FLAME_LMK_PATH = "flame_model/assets/flame/landmark_embedding_with_eyes.npy"
+
+# # to be downloaded from https://flame.is.tue.mpg.de/download.php
+# # FLAME_MODEL_PATH = "flame_model/assets/flame/generic_model.pkl" # FLAME 2020
+# FLAME_MODEL_PATH = "flame_model/assets/flame/flame2023.pkl" # FLAME 2023 (versions w/ jaw rotation)
+# FLAME_PARTS_PATH = "flame_model/assets/flame/FLAME_masks.pkl" # FLAME Vertex Masks
+
+def to_tensor(array, dtype=torch.float32):
+ if "torch.tensor" not in str(type(array)):
+ return torch.tensor(array, dtype=dtype)
+
+
+def to_np(array, dtype=np.float32):
+ if "scipy.sparse" in str(type(array)):
+ array = array.todense()
+ return np.array(array, dtype=dtype)
+
+
+class Struct(object):
+ def __init__(self, **kwargs):
+ for key, val in kwargs.items():
+ setattr(self, key, val)
+
+
+def face_vertices(vertices, faces):
+ """
+ :param vertices: [batch size, number of vertices, 3]
+ :param faces: [batch size, number of faces, 3]
+ :return: [batch size, number of faces, 3, 3]
+ """
+ assert vertices.ndimension() == 3
+ assert faces.ndimension() == 3
+ assert vertices.shape[0] == faces.shape[0]
+ assert vertices.shape[2] == 3
+ assert faces.shape[2] == 3
+
+ bs, nv = vertices.shape[:2]
+ bs, nf = faces.shape[:2]
+ device = vertices.device
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+ vertices = vertices.reshape((bs * nv, 3))
+ # pytorch only supports long and byte tensors for indexing
+ return vertices[faces.long()]
+
+
+class FlameHead(nn.Module):
+ """
+ Given flame parameters this class generates a differentiable FLAME function
+ which outputs the a mesh and 2D/3D facial landmarks
+ """
+
+ def __init__(
+ self,
+ shape_params,
+ expr_params,
+ flame_model_path=None,
+ flame_lmk_embedding_path=None,
+ flame_template_mesh_path=None,
+ flame_parts_path=None,
+ include_mask=True,
+ add_teeth=True,
+ add_shoulder=False,
+ flame_arkit_bs_path=None
+ ):
+ super().__init__()
+
+ self.n_shape_params = shape_params
+ self.n_expr_params = expr_params
+ assert expr_params != 52, "The dimension of the ARKIT expression must be equal to 52."
+
+ with open(flame_model_path, "rb") as f:
+ ss = pickle.load(f, encoding="latin1")
+ flame_model = Struct(**ss)
+
+ self.dtype = torch.float32
+ # The vertices of the template model
+ self.register_buffer(
+ "v_template", to_tensor(to_np(flame_model.v_template), dtype=self.dtype)
+ )
+
+ # The shape components and expression
+ shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
+
+ # load arkit bs
+ assert os.path.exists(flame_arkit_bs_path)
+
+ flame_arkit_bs = np.load(flame_arkit_bs_path).astype(np.float32)
+ flame_arkit_bs = torch.from_numpy(flame_arkit_bs).float().permute(1, 2, 0)
+
+ shapedirs = torch.cat(
+ [shapedirs[:, :, :shape_params], flame_arkit_bs],
+ 2,
+ )
+ self.register_buffer("shapedirs", shapedirs)
+
+ # The pose components
+ num_pose_basis = flame_model.posedirs.shape[-1]
+ posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
+ self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
+ #
+ self.register_buffer(
+ "J_regressor", to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype)
+ )
+ parents = to_tensor(to_np(flame_model.kintree_table[0])).long()
+ parents[0] = -1
+ self.register_buffer("parents", parents)
+ self.register_buffer(
+ "lbs_weights", to_tensor(to_np(flame_model.weights), dtype=self.dtype)
+ )
+
+ # Landmark embeddings for FLAME
+ lmk_embeddings = np.load(
+ flame_lmk_embedding_path, allow_pickle=True, encoding="latin1"
+ )
+ lmk_embeddings = lmk_embeddings[()]
+ self.register_buffer(
+ "full_lmk_faces_idx",
+ torch.tensor(lmk_embeddings["full_lmk_faces_idx"], dtype=torch.long),
+ )
+ self.register_buffer(
+ "full_lmk_bary_coords",
+ torch.tensor(lmk_embeddings["full_lmk_bary_coords"], dtype=self.dtype),
+ )
+
+ neck_kin_chain = []
+ NECK_IDX = 1
+ curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
+ while curr_idx != -1:
+ neck_kin_chain.append(curr_idx)
+ curr_idx = self.parents[curr_idx]
+ self.register_buffer("neck_kin_chain", torch.stack(neck_kin_chain))
+
+ # add faces and uvs
+ verts, faces, aux = load_obj(flame_template_mesh_path, load_textures=False)
+
+ vertex_uvs = aux.verts_uvs
+ face_uvs_idx = faces.textures_idx # index into verts_uvs
+
+ # create uvcoords per face --> this is what you can use for uv map rendering
+ # range from -1 to 1 (-1, -1) = left top; (+1, +1) = right bottom
+ # pad 1 to the end
+ pad = torch.ones(vertex_uvs.shape[0], 1)
+ vertex_uvs = torch.cat([vertex_uvs, pad], dim=-1)
+ vertex_uvs = vertex_uvs * 2 - 1
+ vertex_uvs[..., 1] = -vertex_uvs[..., 1]
+
+ face_uv_coords = face_vertices(vertex_uvs[None], face_uvs_idx[None])[0]
+ self.register_buffer("face_uvcoords", face_uv_coords, persistent=False)
+ self.register_buffer("faces", faces.verts_idx, persistent=False)
+
+ self.register_buffer("verts_uvs", aux.verts_uvs, persistent=False)
+ self.register_buffer("textures_idx", faces.textures_idx, persistent=False)
+ # Check our template mesh faces match those of FLAME:
+ assert (self.faces == torch.from_numpy(flame_model.f.astype('int64'))).all()
+ if include_mask:
+ self.mask = FlameMask(
+ flame_parts_path=flame_parts_path,
+ faces=self.faces,
+ faces_t=self.textures_idx,
+ num_verts=self.v_template.shape[0],
+ num_faces=self.faces.shape[0],
+ )
+
+ if add_teeth:
+ self.add_teeth()
+
+ self.add_shoulder = add_shoulder
+ if (add_shoulder):
+ import trimesh
+ shoulder_mesh = trimesh.load('flame_model/assets/shoulder_mesh.obj')
+ self.v_shoulder = torch.tensor(shoulder_mesh.vertices).float()
+ self.f_shoulder = torch.tensor(shoulder_mesh.faces) + self.v_template.shape[0]
+
+ self.v_template = torch.cat([self.v_template, self.v_shoulder], dim=0)
+ self.faces = torch.cat([self.faces, self.f_shoulder])
+
+ # num_verts_shoulder = shoulder_v.shape[0]
+ # self.v_template = torch.cat([self.v_template, shoulder_v], dim=0)
+ #
+ # shapedirs_shoulder = torch.zeros((num_verts_shoulder,3,400)).float()
+ # self.shapedirs = torch.concat([self.shapedirs,shapedirs_shoulder],dim=0)
+ #
+ # # posedirs set to zero
+ # posedirs = self.posedirs.reshape(len(self.parents) - 1, 9, num_verts_orig, 3) # (J*9, V*3) -> (J, 9, V, 3)
+ # posedirs = torch.cat([posedirs, torch.zeros_like(posedirs[:, :, :num_verts_shoulder])],dim=2) # (J, 9, V+num_verts_teeth, 3)
+ # self.posedirs = posedirs.reshape((len(self.parents) - 1) * 9, (num_verts_orig + num_verts_shoulder) * 3) # (J*9, (V+num_verts_teeth)*3)
+ #
+ # # J_regressor set to zero
+ # self.J_regressor = torch.cat([self.J_regressor, torch.zeros_like(self.J_regressor[:, :num_verts_shoulder])], dim=1) # (5, J) -> (5, J+num_verts_teeth)
+ #
+ # # lbs_weights manually set
+ # self.lbs_weights = torch.cat([self.lbs_weights, torch.zeros_like(self.lbs_weights[:num_verts_shoulder])],dim=0) # (V, 5) -> (V+num_verts_teeth, 5)
+ #
+ #
+ # self.lbs_weights[vid_teeth_upper, 1] += 1 # move with neck
+ # self.lbs_weights[vid_teeth_lower, 2] += 1 # move with jaw
+ #
+ # self.faces = torch.cat([self.faces, f_teeth_upper + num_verts_orig, f_teeth_lower + num_verts_orig], dim=0)
+ # self.textures_idx = torch.cat(
+ # [self.textures_idx, f_teeth_upper + num_verts_uv_orig, f_teeth_lower + num_verts_uv_orig], dim=0)
+ #
+ # self.mask.update(self.faces, self.textures_idx)
+
+ # import trimesh
+ # mesh = trimesh.Trimesh()
+ # mesh.vertices = to_np(self.v_template)
+ # mesh.faces = to_np(self.faces, dtype=np.int64)
+ # mesh.export('/home/yuanzhen/flame_2023_w_shoulder.obj')
+ # exit()
+
+ def add_teeth(self):
+ # get reference vertices from lips
+ vid_lip_outside_ring_upper = self.mask.get_vid_by_region(['lip_outside_ring_upper'], keep_order=True)
+
+ vid_lip_outside_ring_lower = self.mask.get_vid_by_region(['lip_outside_ring_lower'], keep_order=True)
+
+ v_lip_upper = self.v_template[vid_lip_outside_ring_upper]
+ v_lip_lower = self.v_template[vid_lip_outside_ring_lower]
+
+ # construct vertices for teeth
+ mean_dist = (v_lip_upper - v_lip_lower).norm(dim=-1, keepdim=True).mean()
+ v_teeth_middle = (v_lip_upper + v_lip_lower) / 2
+ v_teeth_middle[:, 1] = v_teeth_middle[:, [1]].mean(dim=0, keepdim=True)
+ # v_teeth_middle[:, 2] -= mean_dist * 2.5 # how far the teeth are from the lips
+ # v_teeth_middle[:, 2] -= mean_dist * 2 # how far the teeth are from the lips
+ v_teeth_middle[:, 2] -= mean_dist * 1.5 # how far the teeth are from the lips
+
+ # upper, front
+ v_teeth_upper_edge = v_teeth_middle.clone() + torch.tensor([[0, mean_dist, 0]]) * 0.1
+ v_teeth_upper_root = v_teeth_upper_edge + torch.tensor([[0, mean_dist, 0]]) * 2 # scale the height of teeth
+
+ # lower, front
+ v_teeth_lower_edge = v_teeth_middle.clone() - torch.tensor([[0, mean_dist, 0]]) * 0.1
+ # v_teeth_lower_edge -= torch.tensor([[0, 0, mean_dist]]) * 0.2 # slightly move the lower teeth to the back
+ v_teeth_lower_edge -= torch.tensor([[0, 0, mean_dist]]) * 0.4 # slightly move the lower teeth to the back
+ v_teeth_lower_root = v_teeth_lower_edge - torch.tensor([[0, mean_dist, 0]]) * 2 # scale the height of teeth
+
+ # thickness = mean_dist * 0.5
+ thickness = mean_dist * 1.
+ # upper, back
+ v_teeth_upper_root_back = v_teeth_upper_root.clone()
+ v_teeth_upper_edge_back = v_teeth_upper_edge.clone()
+ v_teeth_upper_root_back[:, 2] -= thickness # how thick the teeth are
+ v_teeth_upper_edge_back[:, 2] -= thickness # how thick the teeth are
+
+ # lower, back
+ v_teeth_lower_root_back = v_teeth_lower_root.clone()
+ v_teeth_lower_edge_back = v_teeth_lower_edge.clone()
+ v_teeth_lower_root_back[:, 2] -= thickness # how thick the teeth are
+ v_teeth_lower_edge_back[:, 2] -= thickness # how thick the teeth are
+
+ # concatenate to v_template
+ num_verts_orig = self.v_template.shape[0]
+ v_teeth = torch.cat([
+ v_teeth_upper_root, # num_verts_orig + 0-14
+ v_teeth_lower_root, # num_verts_orig + 15-29
+ v_teeth_upper_edge, # num_verts_orig + 30-44
+ v_teeth_lower_edge, # num_verts_orig + 45-59
+ v_teeth_upper_root_back, # num_verts_orig + 60-74
+ v_teeth_upper_edge_back, # num_verts_orig + 75-89
+ v_teeth_lower_root_back, # num_verts_orig + 90-104
+ v_teeth_lower_edge_back, # num_verts_orig + 105-119
+ ], dim=0)
+ num_verts_teeth = v_teeth.shape[0]
+ self.v_template = torch.cat([self.v_template, v_teeth], dim=0)
+
+ vid_teeth_upper_root = torch.arange(0, 15) + num_verts_orig
+ vid_teeth_lower_root = torch.arange(15, 30) + num_verts_orig
+ vid_teeth_upper_edge = torch.arange(30, 45) + num_verts_orig
+ vid_teeth_lower_edge = torch.arange(45, 60) + num_verts_orig
+ vid_teeth_upper_root_back = torch.arange(60, 75) + num_verts_orig
+ vid_teeth_upper_edge_back = torch.arange(75, 90) + num_verts_orig
+ vid_teeth_lower_root_back = torch.arange(90, 105) + num_verts_orig
+ vid_teeth_lower_edge_back = torch.arange(105, 120) + num_verts_orig
+
+ vid_teeth_upper = torch.cat(
+ [vid_teeth_upper_root, vid_teeth_upper_edge, vid_teeth_upper_root_back, vid_teeth_upper_edge_back], dim=0)
+ vid_teeth_lower = torch.cat(
+ [vid_teeth_lower_root, vid_teeth_lower_edge, vid_teeth_lower_root_back, vid_teeth_lower_edge_back], dim=0)
+ vid_teeth = torch.cat([vid_teeth_upper, vid_teeth_lower], dim=0)
+
+ # update vertex masks
+ self.mask.v.register_buffer("teeth_upper", vid_teeth_upper)
+ self.mask.v.register_buffer("teeth_lower", vid_teeth_lower)
+ self.mask.v.register_buffer("teeth", vid_teeth)
+ self.mask.v.left_half = torch.cat([
+ self.mask.v.left_half,
+ torch.tensor([
+ 5023, 5024, 5025, 5026, 5027, 5028, 5029, 5030, 5038, 5039, 5040, 5041, 5042, 5043, 5044, 5045, 5053,
+ 5054, 5055, 5056, 5057, 5058, 5059, 5060, 5068, 5069, 5070, 5071, 5072, 5073, 5074, 5075, 5083, 5084,
+ 5085, 5086, 5087, 5088, 5089, 5090, 5098, 5099, 5100, 5101, 5102, 5103, 5104, 5105, 5113, 5114, 5115,
+ 5116, 5117, 5118, 5119, 5120, 5128, 5129, 5130, 5131, 5132, 5133, 5134, 5135,
+ ])], dim=0)
+
+ self.mask.v.right_half = torch.cat([
+ self.mask.v.right_half,
+ torch.tensor([
+ 5030, 5031, 5032, 5033, 5034, 5035, 5036, 5037, 5045, 5046, 5047, 5048, 5049, 5050, 5051, 5052, 5060,
+ 5061, 5062, 5063, 5064, 5065, 5066, 5067, 5075, 5076, 5077, 5078, 5079, 5080, 5081, 5082, 5090, 5091,
+ 5092, 5093, 5094, 5095, 5097, 5105, 5106, 5107, 5108, 5109, 5110, 5111, 5112, 5120, 5121, 5122, 5123,
+ 5124, 5125, 5126, 5127, 5135, 5136, 5137, 5138, 5139, 5140, 5141, 5142,
+ ])], dim=0)
+
+ # construct uv vertices for teeth
+ u = torch.linspace(0.62, 0.38, 15)
+ v = torch.linspace(1 - 0.0083, 1 - 0.0425, 7)
+ # v = v[[0, 2, 1, 1]]
+ # v = v[[0, 3, 1, 4, 3, 2, 6, 5]]
+ v = v[[3, 2, 0, 1, 3, 4, 6, 5]] # TODO: with this order, teeth_lower is not rendered correctly in the uv space
+ uv = torch.stack(torch.meshgrid(u, v, indexing='ij'), dim=-1).permute(1, 0, 2).reshape(num_verts_teeth,
+ 2) # (#num_teeth, 2)
+ num_verts_uv_orig = self.verts_uvs.shape[0]
+ num_verts_uv_teeth = uv.shape[0]
+ self.verts_uvs = torch.cat([self.verts_uvs, uv], dim=0)
+
+ # shapedirs copy from lips
+ self.shapedirs = torch.cat([self.shapedirs, torch.zeros_like(self.shapedirs[:num_verts_teeth])], dim=0)
+ shape_dirs_mean = (self.shapedirs[vid_lip_outside_ring_upper, :, :self.n_shape_params] + self.shapedirs[
+ vid_lip_outside_ring_lower,
+ :,
+ :self.n_shape_params]) / 2
+ self.shapedirs[vid_teeth_upper_root, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_root, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_upper_edge, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_edge, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_upper_root_back, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_upper_edge_back, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_root_back, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_edge_back, :, :self.n_shape_params] = shape_dirs_mean
+
+ # posedirs set to zero
+ posedirs = self.posedirs.reshape(len(self.parents) - 1, 9, num_verts_orig, 3) # (J*9, V*3) -> (J, 9, V, 3)
+ posedirs = torch.cat([posedirs, torch.zeros_like(posedirs[:, :, :num_verts_teeth])],
+ dim=2) # (J, 9, V+num_verts_teeth, 3)
+ self.posedirs = posedirs.reshape((len(self.parents) - 1) * 9,
+ (num_verts_orig + num_verts_teeth) * 3) # (J*9, (V+num_verts_teeth)*3)
+
+ # J_regressor set to zero
+ self.J_regressor = torch.cat([self.J_regressor, torch.zeros_like(self.J_regressor[:, :num_verts_teeth])],
+ dim=1) # (5, J) -> (5, J+num_verts_teeth)
+
+ # lbs_weights manually set
+ self.lbs_weights = torch.cat([self.lbs_weights, torch.zeros_like(self.lbs_weights[:num_verts_teeth])],
+ dim=0) # (V, 5) -> (V+num_verts_teeth, 5)
+ self.lbs_weights[vid_teeth_upper, 1] += 1 # move with neck
+ self.lbs_weights[vid_teeth_lower, 2] += 1 # move with jaw
+
+ # add faces for teeth
+ f_teeth_upper = torch.tensor([
+ [0, 31, 30], # 0
+ [0, 1, 31], # 1
+ [1, 32, 31], # 2
+ [1, 2, 32], # 3
+ [2, 33, 32], # 4
+ [2, 3, 33], # 5
+ [3, 34, 33], # 6
+ [3, 4, 34], # 7
+ [4, 35, 34], # 8
+ [4, 5, 35], # 9
+ [5, 36, 35], # 10
+ [5, 6, 36], # 11
+ [6, 37, 36], # 12
+ [6, 7, 37], # 13
+ [7, 8, 37], # 14
+ [8, 38, 37], # 15
+ [8, 9, 38], # 16
+ [9, 39, 38], # 17
+ [9, 10, 39], # 18
+ [10, 40, 39], # 19
+ [10, 11, 40], # 20
+ [11, 41, 40], # 21
+ [11, 12, 41], # 22
+ [12, 42, 41], # 23
+ [12, 13, 42], # 24
+ [13, 43, 42], # 25
+ [13, 14, 43], # 26
+ [14, 44, 43], # 27
+ [60, 75, 76], # 56
+ [60, 76, 61], # 57
+ [61, 76, 77], # 58
+ [61, 77, 62], # 59
+ [62, 77, 78], # 60
+ [62, 78, 63], # 61
+ [63, 78, 79], # 62
+ [63, 79, 64], # 63
+ [64, 79, 80], # 64
+ [64, 80, 65], # 65
+ [65, 80, 81], # 66
+ [65, 81, 66], # 67
+ [66, 81, 82], # 68
+ [66, 82, 67], # 69
+ [67, 82, 68], # 70
+ [68, 82, 83], # 71
+ [68, 83, 69], # 72
+ [69, 83, 84], # 73
+ [69, 84, 70], # 74
+ [70, 84, 85], # 75
+ [70, 85, 71], # 76
+ [71, 85, 86], # 77
+ [71, 86, 72], # 78
+ [72, 86, 87], # 79
+ [72, 87, 73], # 80
+ [73, 87, 88], # 81
+ [73, 88, 74], # 82
+ [74, 88, 89], # 83
+ [75, 30, 76], # 84
+ [76, 30, 31], # 85
+ [76, 31, 77], # 86
+ [77, 31, 32], # 87
+ [77, 32, 78], # 88
+ [78, 32, 33], # 89
+ [78, 33, 79], # 90
+ [79, 33, 34], # 91
+ [79, 34, 80], # 92
+ [80, 34, 35], # 93
+ [80, 35, 81], # 94
+ [81, 35, 36], # 95
+ [81, 36, 82], # 96
+ [82, 36, 37], # 97
+ [82, 37, 38], # 98
+ [82, 38, 83], # 99
+ [83, 38, 39], # 100
+ [83, 39, 84], # 101
+ [84, 39, 40], # 102
+ [84, 40, 85], # 103
+ [85, 40, 41], # 104
+ [85, 41, 86], # 105
+ [86, 41, 42], # 106
+ [86, 42, 87], # 107
+ [87, 42, 43], # 108
+ [87, 43, 88], # 109
+ [88, 43, 44], # 110
+ [88, 44, 89], # 111
+ ])
+ f_teeth_lower = torch.tensor([
+ [45, 46, 15], # 28
+ [46, 16, 15], # 29
+ [46, 47, 16], # 30
+ [47, 17, 16], # 31
+ [47, 48, 17], # 32
+ [48, 18, 17], # 33
+ [48, 49, 18], # 34
+ [49, 19, 18], # 35
+ [49, 50, 19], # 36
+ [50, 20, 19], # 37
+ [50, 51, 20], # 38
+ [51, 21, 20], # 39
+ [51, 52, 21], # 40
+ [52, 22, 21], # 41
+ [52, 23, 22], # 42
+ [52, 53, 23], # 43
+ [53, 24, 23], # 44
+ [53, 54, 24], # 45
+ [54, 25, 24], # 46
+ [54, 55, 25], # 47
+ [55, 26, 25], # 48
+ [55, 56, 26], # 49
+ [56, 27, 26], # 50
+ [56, 57, 27], # 51
+ [57, 28, 27], # 52
+ [57, 58, 28], # 53
+ [58, 29, 28], # 54
+ [58, 59, 29], # 55
+ [90, 106, 105], # 112
+ [90, 91, 106], # 113
+ [91, 107, 106], # 114
+ [91, 92, 107], # 115
+ [92, 108, 107], # 116
+ [92, 93, 108], # 117
+ [93, 109, 108], # 118
+ [93, 94, 109], # 119
+ [94, 110, 109], # 120
+ [94, 95, 110], # 121
+ [95, 111, 110], # 122
+ [95, 96, 111], # 123
+ [96, 112, 111], # 124
+ [96, 97, 112], # 125
+ [97, 98, 112], # 126
+ [98, 113, 112], # 127
+ [98, 99, 113], # 128
+ [99, 114, 113], # 129
+ [99, 100, 114], # 130
+ [100, 115, 114], # 131
+ [100, 101, 115], # 132
+ [101, 116, 114], # 133
+ [101, 102, 116], # 134
+ [102, 117, 116], # 135
+ [102, 103, 117], # 136
+ [103, 118, 117], # 137
+ [103, 104, 118], # 138
+ [104, 119, 118], # 139
+ [105, 106, 45], # 140
+ [106, 46, 45], # 141
+ [106, 107, 46], # 142
+ [107, 47, 46], # 143
+ [107, 108, 47], # 144
+ [108, 48, 47], # 145
+ [108, 109, 48], # 146
+ [109, 49, 48], # 147
+ [109, 110, 49], # 148
+ [110, 50, 49], # 149
+ [110, 111, 50], # 150
+ [111, 51, 50], # 151
+ [111, 112, 51], # 152
+ [112, 52, 51], # 153
+ [112, 53, 52], # 154
+ [112, 113, 53], # 155
+ [113, 54, 53], # 156
+ [113, 114, 54], # 157
+ [114, 55, 54], # 158
+ [114, 115, 55], # 159
+ [115, 56, 55], # 160
+ [115, 116, 56], # 161
+ [116, 57, 56], # 162
+ [116, 117, 57], # 163
+ [117, 58, 57], # 164
+ [117, 118, 58], # 165
+ [118, 59, 58], # 166
+ [118, 119, 59], # 167
+ ])
+ self.faces = torch.cat([self.faces, f_teeth_upper + num_verts_orig, f_teeth_lower + num_verts_orig], dim=0)
+ self.textures_idx = torch.cat(
+ [self.textures_idx, f_teeth_upper + num_verts_uv_orig, f_teeth_lower + num_verts_uv_orig], dim=0)
+
+ self.mask.update(self.faces, self.textures_idx)
+
+ def forward(
+ self,
+ shape,
+ expr,
+ rotation,
+ neck,
+ jaw,
+ eyes,
+ translation,
+ zero_centered_at_root_node=False, # otherwise, zero centered at the face
+ return_landmarks=True,
+ return_verts_cano=False,
+ static_offset=None,
+ dynamic_offset=None,
+ ):
+ """
+ Input:
+ shape_params: N X number of shape parameters
+ expression_params: N X number of expression parameters
+ pose_params: N X number of pose parameters (6)
+ return:d
+ vertices: N X V X 3
+ landmarks: N X number of landmarks X 3
+ """
+ batch_size = shape.shape[0]
+
+ betas = torch.cat([shape, expr], dim=1)
+ full_pose = torch.cat([rotation, neck, jaw, eyes], dim=1)
+
+ if (self.add_shoulder):
+ template_vertices = self.v_template[:(self.v_template.shape[0] - self.v_shoulder.shape[0])].unsqueeze(
+ 0).expand(batch_size, -1, -1)
+ else:
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Add shape contribution
+ v_shaped_woexpr = template_vertices + blend_shapes(torch.cat([betas[:, :self.n_shape_params],
+ torch.zeros_like(betas[:, self.n_shape_params:])],
+ dim=1), self.shapedirs)
+ v_shaped = template_vertices + blend_shapes(betas, self.shapedirs)
+
+ # import trimesh
+ # mesh = trimesh.Trimesh()
+ # mesh.vertices = np.array(v_shaped.cpu().squeeze())
+ # mesh.faces = np.array(self.faces.cpu().squeeze())
+ # mesh.export('/media/yuanzhen/HH/offset_flame.obj')
+
+ # Add personal offsets
+ if static_offset is not None:
+ if (self.add_shoulder):
+ v_shaped += static_offset[:, :(self.v_template.shape[0] - self.v_shoulder.shape[0])]
+ else:
+ v_shaped += static_offset
+
+ # mesh.vertices = np.array(v_shaped.cpu().squeeze())
+ # mesh.export('/media/yuanzhen/HH/cano_flame.obj')
+ # exit()
+
+ vertices, J, mat_rot = lbs(
+ full_pose,
+ v_shaped,
+ self.posedirs,
+ self.J_regressor,
+ self.parents,
+ self.lbs_weights,
+ dtype=self.dtype,
+ )
+ if (self.add_shoulder):
+ v_shaped = torch.cat([v_shaped,
+ self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(
+ 0).expand(batch_size, -1, -1)], dim=1)
+ vertices = torch.cat([vertices,
+ self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(
+ 0).expand(batch_size, -1, -1)], dim=1)
+
+ if zero_centered_at_root_node:
+ vertices = vertices - J[:, [0]]
+ J = J - J[:, [0]]
+
+ vertices = vertices + translation[:, None, :]
+ J = J + translation[:, None, :]
+
+ ret_vals = {}
+ ret_vals["animated"] = vertices
+
+ if return_verts_cano:
+ ret_vals["cano"] = v_shaped_woexpr
+ ret_vals["cano_with_expr"] = v_shaped
+
+ # compute landmarks if desired
+ if return_landmarks:
+ bz = vertices.shape[0]
+ landmarks = vertices2landmarks(
+ vertices,
+ self.faces,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1),
+ )
+ ret_vals["landmarks"] = landmarks
+
+ return ret_vals
+
+
+class FlameHeadSubdivided(FlameHead):
+ """
+ Given flame parameters this class generates a differentiable FLAME function
+ which outputs the a mesh and 2D/3D facial landmarks
+ """
+
+ def __init__(
+ self,
+ shape_params,
+ expr_params,
+ flame_model_path=None,
+ flame_lmk_embedding_path=None,
+ flame_template_mesh_path=None,
+ flame_parts_path=None,
+ include_mask=True,
+ add_teeth=True,
+ add_shoulder=False,
+ subdivide_num=0,
+ flame_arkit_bs_path=None,
+ ):
+ super().__init__(shape_params=shape_params,
+ expr_params=expr_params,
+ flame_model_path=flame_model_path,
+ flame_lmk_embedding_path=flame_lmk_embedding_path,
+ flame_template_mesh_path=flame_template_mesh_path,
+ include_mask=include_mask,
+ add_teeth=add_teeth,
+ add_shoulder=add_shoulder,
+ flame_parts_path=flame_parts_path,
+ flame_arkit_bs_path=flame_arkit_bs_path
+ )
+
+ # subdivider
+ self.subdivide_num = subdivide_num
+ self.subdivider_list = self.get_subdivider(subdivide_num)
+ self.subdivider_cpu_list = self.get_subdivider_cpu(subdivide_num)
+ self.face_upsampled = self.subdivider_list[
+ -1]._subdivided_faces.cpu().numpy() if self.subdivide_num > 0 else self.faces.numpy()
+ self.vertex_num_upsampled = int(np.max(self.face_upsampled) + 1)
+
+ self.vertex_num = self.v_template.shape[0]
+ self.joint_num = self.J_regressor.shape[0]
+ print(f"face_upsampled:{self.face_upsampled.shape}, face_ori:{self.faces.shape}, \
+ vertex_num_upsampled:{self.vertex_num_upsampled}, vertex_num_ori:{self.vertex_num}")
+
+ lbs_weights = self.lbs_weights.float()
+ posedirs = self.posedirs.permute(1, 0).reshape(self.vertex_num, 3 * (self.joint_num - 1) * 9)
+ # expr_dirs = self.expr_dirs.view(self.vertex_num, 3 * self.n_expr_params)
+ shapedirs = self.shapedirs.view(self.vertex_num, 3 * (self.n_shape_params + self.n_expr_params))
+ J_regressor = self.J_regressor.permute(1, 0)
+
+ v_template_upsampled, lbs_weights, posedirs, shapedirs, J_regressor = \
+ self.upsample_mesh_cpu(self.v_template.float(),
+ [lbs_weights,
+ posedirs,
+ shapedirs,
+ J_regressor,
+ ],
+ ) # upsample with dummy vertex
+
+ posedirs = posedirs.reshape(self.vertex_num_upsampled * 3, (self.joint_num - 1) * 9).permute(1, 0)
+ shapedirs = shapedirs.view(self.vertex_num_upsampled, 3, (self.n_shape_params + self.n_expr_params))
+ J_regressor = J_regressor.permute(1, 0)
+
+ self.register_buffer('faces', torch.from_numpy(self.face_upsampled))
+ self.register_buffer('v_template_up', v_template_upsampled.contiguous())
+ self.register_buffer('lbs_weights_up', lbs_weights.contiguous())
+ # self.register_buffer('posedirs', posedirs.contiguous())
+ self.register_buffer('shapedirs_up', shapedirs.contiguous())
+ # self.register_buffer('J_regressor', J_regressor.contiguous())
+
+ def get_cano_verts(self, shape_params):
+ # TODO check
+ assert self.add_shoulder == False
+ batch_size = shape_params.shape[0]
+
+ template_vertices = self.v_template_up.unsqueeze(0).expand(batch_size, -1, -1)
+
+ v_shaped = template_vertices + blend_shapes(shape_params, self.shapedirs_up[:, :, :self.n_shape_params])
+
+ return v_shaped
+
+ def animation_forward(self,
+ v_cano,
+ shape,
+ expr,
+ rotation,
+ neck,
+ jaw,
+ eyes,
+ translation,
+ zero_centered_at_root_node=False, # otherwise, zero centered at the face
+ return_landmarks=True,
+ return_verts_cano=False,
+ static_offset=None,
+ dynamic_offset=None,
+ ):
+ assert self.add_shoulder == False
+ assert static_offset is None
+
+ batch_size = shape.shape[0]
+
+ # step1. get animated_joint and corresponding transformed mat (Note not in upsampled space)
+ betas = torch.cat([shape, expr], dim=1)
+ full_pose = torch.cat([rotation, neck, jaw, eyes], dim=1)
+
+ if (self.add_shoulder):
+ template_vertices = self.v_template[:(self.v_template.shape[0] - self.v_shoulder.shape[0])].unsqueeze(
+ 0).expand(batch_size, -1, -1)
+ else:
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Add shape contribution
+ v_shaped = template_vertices + blend_shapes(betas, self.shapedirs)
+
+ # Add personal offsets
+ if static_offset is not None:
+ if (self.add_shoulder):
+ v_shaped += static_offset[:, :(self.v_template.shape[0] - self.v_shoulder.shape[0])]
+ else:
+ v_shaped += static_offset
+
+ A, J = self.get_transformed_mat(pose=full_pose, v_shaped=v_shaped, posedirs=self.posedirs,
+ parents=self.parents, J_regressor=self.J_regressor, pose2rot=True,
+ dtype=self.dtype)
+
+ # step2. v_cano_with_expr
+ v_cano_with_expr = v_cano + blend_shapes(expr, self.shapedirs_up[:, :, self.n_shape_params:])
+
+ # step3. lbs
+ vertices = self.skinning(v_posed=v_cano_with_expr, A=A, lbs_weights=self.lbs_weights_up, batch_size=batch_size,
+ num_joints=self.joint_num, dtype=self.dtype, device=full_pose.device)
+
+ if (self.add_shoulder):
+ v_shaped = torch.cat([v_shaped,
+ self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(
+ 0).expand(batch_size, -1, -1)], dim=1)
+ vertices = torch.cat([vertices,
+ self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(
+ 0).expand(batch_size, -1, -1)], dim=1)
+
+ if zero_centered_at_root_node:
+ vertices = vertices - J[:, [0]]
+ J = J - J[:, [0]]
+
+ vertices = vertices + translation[:, None, :]
+ J = J + translation[:, None, :]
+
+ ret_vals = {}
+ ret_vals["animated"] = vertices
+
+ if return_verts_cano:
+ ret_vals["cano"] = v_cano
+ ret_vals["cano_with_expr"] = v_cano_with_expr
+
+ # compute landmarks if desired
+ if return_landmarks:
+ bz = vertices.shape[0]
+ landmarks = vertices2landmarks(
+ vertices,
+ self.faces,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1),
+ )
+ ret_vals["landmarks"] = landmarks
+
+ return ret_vals
+
+ def get_transformed_mat(self, pose, v_shaped, posedirs, parents, J_regressor, pose2rot, dtype):
+ batch_size = pose.shape[0]
+ device = pose.device
+
+ # Get the joints
+ # NxJx3 array
+ J = vertices2joints(J_regressor, v_shaped)
+
+ # 3. Add pose blend shapes
+ # N x J x 3 x 3
+ ident = torch.eye(3, dtype=dtype, device=device)
+ if pose2rot:
+ rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view(
+ [batch_size, -1, 3, 3]
+ )
+
+ pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
+ # (N x P) x (P, V * 3) -> N x V x 3
+ pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
+ else:
+ pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
+ rot_mats = pose.view(batch_size, -1, 3, 3)
+
+ pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(
+ batch_size, -1, 3
+ )
+
+ v_posed = pose_offsets + v_shaped
+
+ # 4. Get the global joint location
+ J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
+
+ return A, J_transformed
+
+ def skinning(self, v_posed, A, lbs_weights, batch_size, num_joints, dtype, device):
+
+ # 5. Do skinning:
+ # W is N x V x (J + 1)
+ W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
+ # (N x V x (J + 1)) x (N x (J + 1) x 16)
+ # num_joints = J_regressor.shape[0]
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
+
+ homogen_coord = torch.ones(
+ [batch_size, v_posed.shape[1], 1], dtype=dtype, device=device
+ )
+ v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
+ v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
+ verts = v_homo[:, :, :3, 0]
+
+ return verts
+
+ def forward(
+ self,
+ shape,
+ expr,
+ rotation,
+ neck,
+ jaw,
+ eyes,
+ translation,
+ zero_centered_at_root_node=False, # otherwise, zero centered at the face
+ return_landmarks=True,
+ return_verts_cano=False,
+ static_offset=None,
+ dynamic_offset=None,
+ ):
+ """
+ Input:
+ shape_params: N X number of shape parameters
+ expression_params: N X number of expression parameters
+ pose_params: N X number of pose parameters (6)
+ return:d
+ vertices: N X V X 3
+ landmarks: N X number of landmarks X 3
+ """
+ batch_size = shape.shape[0]
+
+ betas = torch.cat([shape, expr], dim=1)
+ full_pose = torch.cat([rotation, neck, jaw, eyes], dim=1)
+
+ if (self.add_shoulder):
+ template_vertices = self.v_template[:(self.v_template.shape[0] - self.v_shoulder.shape[0])].unsqueeze(
+ 0).expand(batch_size, -1, -1)
+ else:
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Add shape contribution
+ v_shaped_woexpr = template_vertices + blend_shapes(betas[:, :self.n_shape_params],
+ self.shapedirs[:, :, :self.n_shape_params])
+ v_shaped = template_vertices + blend_shapes(betas, self.shapedirs)
+
+ # Add personal offsets
+ if static_offset is not None:
+ if (self.add_shoulder):
+ v_shaped += static_offset[:, :(self.v_template.shape[0] - self.v_shoulder.shape[0])]
+ else:
+ v_shaped += static_offset
+
+ A, J = self.get_transformed_mat(pose=full_pose, v_shaped=v_shaped, posedirs=self.posedirs,
+ parents=self.parents, J_regressor=self.J_regressor, pose2rot=True,
+ dtype=self.dtype)
+
+ v_shaped_up = self.v_template_up.unsqueeze(0).expand(batch_size, -1, -1) + blend_shapes(betas,
+ self.shapedirs_up)
+ vertices = self.skinning(v_posed=v_shaped_up, A=A, lbs_weights=self.lbs_weights_up, batch_size=batch_size,
+ num_joints=self.joint_num, dtype=self.dtype, device=full_pose.device)
+
+ if (self.add_shoulder):
+ v_shaped = torch.cat([v_shaped,
+ self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(
+ 0).expand(batch_size, -1, -1)], dim=1)
+ vertices = torch.cat([vertices,
+ self.v_template[(self.v_template.shape[0] - self.v_shoulder.shape[0]):].unsqueeze(
+ 0).expand(batch_size, -1, -1)], dim=1)
+
+ if zero_centered_at_root_node:
+ vertices = vertices - J[:, [0]]
+ J = J - J[:, [0]]
+
+ vertices = vertices + translation[:, None, :]
+ J = J + translation[:, None, :]
+
+ ret_vals = {}
+ ret_vals["animated"] = vertices
+
+ if return_verts_cano:
+ ret_vals["cano"] = self.v_template_up.unsqueeze(0).expand(batch_size, -1, -1) + blend_shapes(
+ betas[:, :self.n_shape_params], self.shapedirs_up[:, :, :self.n_shape_params])
+ ret_vals["cano_with_expr"] = v_shaped_up
+
+ # compute landmarks if desired
+ if return_landmarks:
+ bz = vertices.shape[0]
+ landmarks = vertices2landmarks(
+ vertices,
+ self.faces,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1),
+ )
+ ret_vals["landmarks"] = landmarks
+
+ return ret_vals
+
+ def get_subdivider(self, subdivide_num):
+ vert = self.v_template.float().cuda()
+ face = torch.LongTensor(self.faces).cuda()
+ mesh = Meshes(vert[None, :, :], face[None, :, :])
+
+ if subdivide_num > 0:
+ subdivider_list = [SubdivideMeshes(mesh)]
+ for i in range(subdivide_num - 1):
+ mesh = subdivider_list[-1](mesh)
+ subdivider_list.append(SubdivideMeshes(mesh))
+ else:
+ subdivider_list = [mesh]
+ return subdivider_list
+
+ def get_subdivider_cpu(self, subdivide_num):
+ vert = self.v_template.float()
+ face = torch.LongTensor(self.faces)
+ mesh = Meshes(vert[None, :, :], face[None, :, :])
+
+ if subdivide_num > 0:
+ subdivider_list = [SubdivideMeshes(mesh)]
+ for i in range(subdivide_num - 1):
+ mesh = subdivider_list[-1](mesh)
+ subdivider_list.append(SubdivideMeshes(mesh))
+ else:
+ subdivider_list = [mesh]
+ return subdivider_list
+
+ def upsample_mesh_cpu(self, vert, feat_list=None):
+ face = torch.LongTensor(self.faces)
+ mesh = Meshes(vert[None, :, :], face[None, :, :])
+ if self.subdivide_num > 0:
+ if feat_list is None:
+ for subdivider in self.subdivider_cpu_list:
+ mesh = subdivider(mesh)
+ vert = mesh.verts_list()[0]
+ return vert
+ else:
+ feat_dims = [x.shape[1] for x in feat_list]
+ feats = torch.cat(feat_list, 1)
+ for subdivider in self.subdivider_cpu_list:
+ mesh, feats = subdivider(mesh, feats)
+ vert = mesh.verts_list()[0]
+ feats = feats[0]
+ feat_list = torch.split(feats, feat_dims, dim=1)
+ return vert, *feat_list
+ else:
+ if feat_list is None:
+ # for subdivider in self.subdivider_cpu_list:
+ # mesh = subdivider(mesh)
+ # vert = mesh.verts_list()[0]
+ return vert
+ else:
+ # feat_dims = [x.shape[1] for x in feat_list]
+ # feats = torch.cat(feat_list,1)
+ # for subdivider in self.subdivider_cpu_list:
+ # mesh, feats = subdivider(mesh, feats)
+ # vert = mesh.verts_list()[0]
+ # feats = feats[0]
+ # feat_list = torch.split(feats, feat_dims, dim=1)
+ return vert, *feat_list
+
+ def upsample_mesh(self, vert, feat_list=None, device="cuda"):
+ face = torch.LongTensor(self.faces).to(device)
+ mesh = Meshes(vert[None, :, :], face[None, :, :])
+ if self.subdivide_num > 0:
+ if feat_list is None:
+ for subdivider in self.subdivider_list:
+ mesh = subdivider(mesh)
+ vert = mesh.verts_list()[0]
+ return vert
+ else:
+ feat_dims = [x.shape[1] for x in feat_list]
+ feats = torch.cat(feat_list, 1)
+ for subdivider in self.subdivider_list:
+ mesh, feats = subdivider(mesh, feats)
+ vert = mesh.verts_list()[0]
+ feats = feats[0]
+ feat_list = torch.split(feats, feat_dims, dim=1)
+ return vert, *feat_list
+ else:
+ if feat_list is None:
+ # for subdivider in self.subdivider_list:
+ # mesh = subdivider(mesh)
+ # vert = mesh.verts_list()[0]
+ return vert
+ else:
+ # feat_dims = [x.shape[1] for x in feat_list]
+ # feats = torch.cat(feat_list,1)
+ # for subdivider in self.subdivider_list:
+ # mesh, feats = subdivider(mesh, feats)
+ # vert = mesh.verts_list()[0]
+ # feats = feats[0]
+ # feat_list = torch.split(feats, feat_dims, dim=1)
+ return vert, *feat_list
+
+ def upsample_mesh_batch(self, vert, device="cuda"):
+ if self.subdivide_num > 0:
+ face = torch.LongTensor(self.faces).to(device).unsqueeze(0).repeat(vert.shape[0], 1, 1)
+ mesh = Meshes(vert, face)
+ for subdivider in self.subdivider_list:
+ mesh = subdivider(mesh)
+ vert = torch.stack(mesh.verts_list(), dim=0)
+ else:
+ pass
+ return vert
+
+
+class BufferContainer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def __repr__(self):
+ main_str = super().__repr__() + '\n'
+ for name, buf in self.named_buffers():
+ main_str += f' {name:20}\t{buf.shape}\t{buf.dtype}\n'
+ return main_str
+
+ def __iter__(self):
+ for name, buf in self.named_buffers():
+ yield name, buf
+
+ def keys(self):
+ return [name for name, buf in self.named_buffers()]
+
+ def items(self):
+ return [(name, buf) for name, buf in self.named_buffers()]
+
+
+class FlameMask(nn.Module):
+ def __init__(
+ self,
+ flame_parts_path=None,
+ faces=None,
+ faces_t=None,
+ num_verts=5023,
+ num_faces=9976,
+ face_clusters=[],
+ ):
+ super().__init__()
+ self.faces = faces
+ self.faces_t = faces_t
+ self.face_clusters = face_clusters
+ self.num_verts = num_verts
+ if faces is not None:
+ self.num_faces = faces.shape[0]
+ else:
+ self.num_faces = num_faces
+
+ self.process_vertex_mask(flame_parts_path)
+
+ if self.faces is not None:
+ self.construct_vid_table()
+ self.process_face_mask(self.faces)
+ self.process_face_clusters(self.face_clusters)
+ if self.faces_t is not None:
+ self.process_vt_mask(self.faces, self.faces_t)
+
+ def update(self, faces=None, faces_t=None, face_clusters=None):
+ """Update the faces properties when vertex masks are changed"""
+ if faces is not None:
+ self.faces = faces
+ self.num_faces = faces.shape[0]
+ if faces_t is not None:
+ self.faces_t = faces_t
+ if face_clusters is not None:
+ self.face_clusters = face_clusters
+
+ self.construct_vid_table()
+ self.process_face_mask(self.faces)
+ self.process_face_clusters(self.face_clusters)
+ if self.faces_t is not None:
+ self.process_vt_mask(self.faces, self.faces_t)
+
+ def process_vertex_mask(self, flame_parts_path):
+ """Load the vertex masks from the FLAME model and add custom masks"""
+
+ part_masks = np.load(flame_parts_path, allow_pickle=True, encoding="latin1")
+ """ Available part masks from the FLAME model:
+ face, neck, scalp, boundary, right_eyeball, left_eyeball,
+ right_ear, left_ear, forehead, eye_region, nose, lips,
+ right_eye_region, left_eye_region.
+ """
+
+ self.v = BufferContainer()
+ for k, v_mask in part_masks.items():
+ self.v.register_buffer(k, torch.tensor(v_mask, dtype=torch.long))
+
+ self.create_custom_mask()
+
+ def create_custom_mask(self):
+ """Add some cutom masks based on the original FLAME masks"""
+
+ self.v.register_buffer("neck_left_point", torch.tensor([3193]))
+ self.v.register_buffer("neck_right_point", torch.tensor([3296]))
+ self.v.register_buffer("front_middle_bottom_point_boundary", torch.tensor([3285]))
+ self.v.register_buffer("back_middle_bottom_point_boundary", torch.tensor([3248]))
+
+ self.v.register_buffer(
+ "neck_top",
+ torch.tensor([
+ 10, 11, 111, 112, 784, 795, 1325, 1901, 2115, 2162, 2251, 2254, 2483, 2979, 3142, 3174, 3441, 3442,
+ 3443, 3444, 3445, 3446, 3447, 3448, 3449, 3562, 3673, 3676, 3677, 3678, 3679, 3680, 3681, 3685,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_ring_upper",
+ torch.tensor([
+ 1595, 1746, 1747, 1742, 1739, 1665, 1666, 3514, 2783, 2782, 2854, 2857, 2862, 2861, 2731
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_ring_lower",
+ torch.tensor([
+ 1572, 1573, 1860, 1862, 1830, 1835, 1852, 3497, 2941, 2933, 2930, 2945, 2943, 2709, 2708
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_outside_ring_upper",
+ torch.tensor([
+ 1713, 1715, 1716, 1735, 1696, 1694, 1657, 3543, 2774, 2811, 2813, 2850, 2833, 2832, 2830
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_outside_ring_lower",
+ torch.tensor([
+ 1576, 1577, 1773, 1774, 1795, 1802, 1865, 3503, 2948, 2905, 2898, 2881, 2880, 2713, 2712
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_upper",
+ torch.tensor([
+ 1588, 1589, 1590, 1591, 1594, 1595, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1724, 1725, 1739,
+ 1741, 1742, 1743, 1744, 1745, 1746, 1747, 2724, 2725, 2726, 2727, 2730, 2731, 2776, 2777, 2778, 2779,
+ 2780, 2781, 2782, 2783, 2841, 2842, 2854, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 3514, 3547, 3549,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_lower",
+ torch.tensor([
+ 1572, 1573, 1592, 1593, 1764, 1765, 1779, 1780, 1781, 1830, 1831, 1832, 1835, 1846, 1847, 1851, 1852,
+ 1854, 1860, 1861, 1862, 2708, 2709, 2728, 2729, 2872, 2873, 2886, 2887, 2888, 2930, 2931, 2932, 2933,
+ 2935, 2936, 2940, 2941, 2942, 2943, 2944, 2945, 3497, 3500, 3512,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside",
+ torch.tensor([
+ 1572, 1573, 1580, 1581, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1659, 1660, 1661, 1662, 1663,
+ 1664, 1665, 1666, 1667, 1668, 1718, 1719, 1722, 1724, 1725, 1728, 1739, 1740, 1741, 1742, 1743, 1744,
+ 1745, 1746, 1747, 1748, 1764, 1765, 1777, 1778, 1779, 1780, 1781, 1782, 1827, 1830, 1831, 1832, 1835,
+ 1836, 1846, 1847, 1851, 1852, 1854, 1860, 1861, 1862, 2708, 2709, 2716, 2717, 2724, 2725, 2726, 2727,
+ 2728, 2729, 2730, 2731, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2784, 2785, 2835, 2836, 2839,
+ 2841, 2842, 2843, 2854, 2855, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 2863, 2872, 2873, 2884, 2885,
+ 2886, 2887, 2888, 2889, 2929, 2930, 2931, 2932, 2933, 2934, 2935, 2936, 2940, 2941, 2942, 2943, 2944,
+ 2945, 3497, 3500, 3512, 3513, 3514, 3533, 3547, 3549,
+ ])
+ )
+
+ self.v.register_buffer(
+ "neck_upper",
+ torch.tensor([
+ 10, 11, 12, 13, 14, 15, 111, 112, 219, 220, 221, 222, 372, 373, 374, 375, 462, 463, 496, 497, 552, 553,
+ 558, 559, 563, 564, 649, 650, 736, 737, 784, 795, 1210, 1211, 1212, 1213, 1325, 1326, 1359, 1360, 1386,
+ 1726, 1727, 1759, 1790, 1886, 1898, 1901, 1931, 1932, 1933, 1934, 1940, 1941, 1948, 1949, 2036, 2115,
+ 2149, 2150, 2151, 2162, 2218, 2219, 2251, 2254, 2483, 2484, 2531, 2870, 2893, 2964, 2976, 2979, 3012,
+ 3013, 3142, 3174, 3184, 3185, 3186, 3187, 3188, 3189, 3193, 3194, 3196, 3199, 3200, 3202, 3203, 3206,
+ 3209, 3281, 3282, 3286, 3291, 3292, 3296, 3297, 3299, 3302, 3303, 3305, 3306, 3309, 3312, 3376, 3441,
+ 3442, 3443, 3444, 3445, 3446, 3447, 3448, 3449, 3452, 3453, 3454, 3455, 3456, 3457, 3458, 3459, 3460,
+ 3461, 3462, 3463, 3494, 3496, 3544, 3562, 3673, 3676, 3677, 3678, 3679, 3680, 3681, 3685, 3695, 3697,
+ 3698, 3701, 3703, 3707, 3709, 3713,
+ ])
+ )
+
+ self.v.register_buffer(
+ "neck_lower",
+ torch.tensor([
+ 3188, 3189, 3190, 3191, 3192, 3193, 3194, 3195, 3196, 3197, 3198, 3199, 3200, 3201, 3202, 3203, 3204,
+ 3205, 3206, 3207, 3208, 3209, 3210, 3211, 3212, 3213, 3214, 3215, 3220, 3222, 3223, 3231, 3232, 3233,
+ 3234, 3235, 3236, 3237, 3238, 3239, 3240, 3241, 3242, 3243, 3244, 3245, 3246, 3247, 3250, 3251, 3253,
+ 3254, 3263, 3264, 3265, 3266, 3267, 3268, 3269, 3270, 3275, 3276, 3277, 3278, 3281, 3282, 3283, 3286,
+ 3288, 3290, 3291, 3292, 3293, 3294, 3295, 3296, 3297, 3298, 3299, 3300, 3301, 3302, 3303, 3304, 3305,
+ 3306, 3307, 3308, 3309, 3310, 3311, 3312, 3313, 3314, 3315, 3316, 3317, 3318, 3323, 3332, 3333, 3334,
+ 3335, 3336, 3337, 3338, 3339, 3340, 3341, 3342, 3343, 3344, 3345, 3346, 3347, 3348, 3349, 3350, 3352,
+ 3353, 3362, 3363, 3364, 3365, 3366, 3367, 3368, 3369, 3376, 3378,
+ ])
+ )
+
+ # the bottomline of "neck"
+ self.v.register_buffer(
+ "neck_base",
+ torch.tensor([
+ 3231, 3232, 3237, 3238, 3240, 3242, 3243, 3251, 3263, 3290, 3332, 3333, 3338, 3339, 3341, 3343, 3344,
+ 3350, 3362, # 4-th ring from bottom (drop 7 front verts)
+ ])
+ )
+
+ # As a subset of "boundary", "bottomline" only contains vertices on the edge
+ self.v.register_buffer(
+ "bottomline",
+ torch.tensor([
+ 3218, 3219, 3226, 3272, 3273, 3229, 3228, 3261, 3260, 3248, 3359, 3360, 3329, 3330, 3372, 3371, 3327,
+ 3322, 3321, 3355, 3354, 3356, 3357, 3379, 3285, 3289, 3258, 3257, 3255, 3256
+ ])
+ )
+
+ self.v.register_buffer(
+ "left_iris",
+ torch.tensor([
+ 3931, 3932, 3933, 3935, 3936, 3937, 3939, 3940, 3941, 3943, 3944, 3945, 3947, 3948, 3949, 3951, 3952,
+ 3953, 3955, 3956, 3957, 3959, 3960, 3961, 3963, 3964, 3965, 3967, 3968, 3969, 3971, 3972, 3973, 3975,
+ 3976, 3977, 3979, 3980, 3981, 3983, 3984, 3985, 3987, 3988, 3989, 3991, 3992, 3993, 3995, 3996, 3997,
+ 3999, 4000, 4001, 4003, 4004, 4005, 4007, 4008, 4009, 4011, 4012, 4013, 4015, 4016, 4017, 4019, 4020,
+ 4021, 4023, 4024, 4025, 4027, 4028, 4029, 4031, 4032, 4033, 4035, 4036, 4037, 4039, 4040, 4041, 4043,
+ 4044, 4045, 4047, 4048, 4049, 4051, 4052, 4053, 4054, 4056, 4057, 4058,
+ ])
+ )
+
+ self.v.register_buffer(
+ "right_iris",
+ torch.tensor([
+ 4477, 4478, 4479, 4481, 4482, 4483, 4485, 4486, 4487, 4489, 4490, 4491, 4493, 4494, 4495, 4497, 4498,
+ 4499, 4501, 4502, 4503, 4505, 4506, 4507, 4509, 4510, 4511, 4513, 4514, 4515, 4517, 4518, 4519, 4521,
+ 4522, 4523, 4525, 4526, 4527, 4529, 4530, 4531, 4533, 4534, 4535, 4537, 4538, 4539, 4541, 4542, 4543,
+ 4545, 4546, 4547, 4549, 4550, 4551, 4553, 4554, 4555, 4557, 4558, 4559, 4561, 4562, 4563, 4565, 4566,
+ 4567, 4569, 4570, 4571, 4573, 4574, 4575, 4577, 4578, 4579, 4581, 4582, 4583, 4585, 4586, 4587, 4589,
+ 4590, 4591, 4593, 4594, 4595, 4597, 4598, 4599, 4600, 4602, 4603, 4604,
+ ])
+ )
+
+ self.v.register_buffer(
+ "left_eyelid", # 30 vertices
+ torch.tensor([
+ 807, 808, 809, 814, 815, 816, 821, 822, 823, 824, 825, 826, 827, 828, 829, 841, 842, 848, 864, 865, 877,
+ 878, 879, 880, 881, 882, 883, 884, 885, 896, 897, 903, 904, 905, 922, 923, 924, 926, 945, 946, 947, 948,
+ 949, 950, 951, 952, 953, 954, 955, 958, 959, 991, 992, 993, 994, 995, 999, 1000, 1003, 1006, 1008, 1011,
+ 1023, 1033, 1034, 1045, 1046, 1059, 1060, 1061, 1062, 1093, 1096, 1101, 1108, 1113, 1114, 1115, 1125,
+ 1126, 1132, 1134, 1135, 1142, 1143, 1144, 1146, 1147, 1150, 1151, 1152, 1153, 1154, 1170, 1175, 1182,
+ 1183, 1194, 1195, 1200, 1201, 1202, 1216, 1217, 1218, 1224, 1227, 1230, 1232, 1233, 1243, 1244, 1283,
+ 1289, 1292, 1293, 1294, 1320, 1329, 1331, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345,
+ 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1361, 3827, 3832, 3833, 3835, 3853, 3855, 3856, 3861,
+ ])
+ )
+
+ self.v.register_buffer(
+ "right_eyelid", # 30 vertices
+ torch.tensor([
+ 2264, 2265, 2266, 2267, 2268, 2269, 2270, 2271, 2272, 2273, 2274, 2275, 2276, 2277, 2278, 2282, 2283,
+ 2286, 2287, 2288, 2289, 2290, 2291, 2292, 2293, 2294, 2295, 2296, 2297, 2298, 2299, 2303, 2304, 2305,
+ 2312, 2313, 2314, 2315, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335,
+ 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2364, 2365, 2367, 2369, 2381, 2382, 2383, 2386, 2387, 2388,
+ 2389, 2390, 2391, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2411, 2412, 2416, 2417, 2418, 2419, 2420,
+ 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2436, 2437, 2440, 2441, 2446, 2447, 2448, 2449, 2450,
+ 2451, 2452, 2453, 2454, 2457, 2460, 2461, 2462, 2465, 2466, 2467, 2470, 2471, 2472, 2473, 2478, 2485,
+ 2486, 2487, 2488, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2503, 2504, 2505, 2506, 2507, 2508,
+ 2509, 2510, 3619, 3631, 3632, 3638, 3687, 3689, 3690, 3700,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lips_tight", # 30 vertices
+ torch.tensor([
+ 1572, 1573, 1578, 1580, 1581, 1582, 1583, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1659, 1660,
+ 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1718, 1719, 1720, 1721, 1722, 1723, 1724,
+ 1725, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743, 1744,
+ 1745, 1746, 1747, 1748, 1750, 1751, 1758, 1764, 1765, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780,
+ 1781, 1782, 1787, 1788, 1789, 1791, 1792, 1793, 1794, 1795, 1802, 1803, 1804, 1826, 1827, 1830, 1831,
+ 1832, 1835, 1836, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1854, 1860, 1861, 1862, 1865, 2708, 2709,
+ 2714, 2716, 2717, 2718, 2719, 2724, 2725, 2726, 2727, 2728, 2729, 2730, 2731, 2776, 2777, 2778, 2779,
+ 2780, 2781, 2782, 2783, 2784, 2785, 2786, 2787, 2835, 2836, 2837, 2838, 2839, 2840, 2841, 2842, 2843,
+ 2844, 2845, 2846, 2847, 2848, 2849, 2851, 2852, 2853, 2854, 2855, 2856, 2857, 2858, 2859, 2860, 2861,
+ 2862, 2863, 2865, 2866, 2869, 2872, 2873, 2880, 2881, 2882, 2883, 2884, 2885, 2886, 2887, 2888, 2889,
+ 2890, 2891, 2892, 2894, 2895, 2896, 2897, 2898, 2905, 2906, 2907, 2928, 2929, 2930, 2931, 2932, 2933,
+ 2934, 2935, 2936, 2937, 2938, 2939, 2940, 2941, 2942, 2943, 2944, 2945, 2948, 3497, 3500, 3503, 3504,
+ 3506, 3509, 3512, 3513, 3514, 3531, 3533, 3546, 3547, 3549,
+ ])
+ )
+
+ self.v.register_buffer(
+ "left_half",
+ torch.tensor([
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 27, 28, 29, 30, 31, 32, 33, 34, 35,
+ 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+ 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+ 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 113, 114,
+ 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135,
+ 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156,
+ 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177,
+ 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198,
+ 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 223,
+ 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244,
+ 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265,
+ 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286,
+ 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307,
+ 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328,
+ 329, 330, 331, 332, 333, 334, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
+ 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374,
+ 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395,
+ 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416,
+ 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437,
+ 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458,
+ 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479,
+ 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500,
+ 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521,
+ 530, 531, 532, 533, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 558,
+ 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579,
+ 580, 581, 582, 583, 588, 589, 590, 591, 592, 593, 594, 603, 604, 605, 622, 623, 624, 625, 626, 627, 628,
+ 629, 630, 631, 632, 633, 638, 639, 644, 645, 646, 647, 648, 649, 650, 667, 668, 669, 670, 671, 672, 673,
+ 674, 679, 680, 681, 682, 683, 688, 691, 692, 693, 694, 695, 696, 697, 702, 703, 704, 705, 706, 707, 708,
+ 709, 712, 713, 714, 715, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738,
+ 739, 740, 745, 746, 747, 748, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767,
+ 768, 769, 770, 771, 772, 773, 774, 775, 783, 784, 785, 786, 795, 796, 797, 798, 799, 802, 803, 804, 805,
+ 806, 807, 808, 809, 814, 815, 816, 821, 822, 823, 824, 825, 826, 827, 828, 829, 837, 838, 840, 841, 842,
+ 846, 847, 848, 864, 865, 877, 878, 879, 880, 881, 882, 883, 884, 885, 896, 897, 898, 899, 902, 903, 904,
+ 905, 906, 907, 908, 909, 918, 919, 922, 923, 924, 926, 927, 928, 929, 939, 942, 943, 944, 945, 946, 947,
+ 948, 949, 950, 951, 952, 953, 954, 955, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970,
+ 971, 972, 977, 978, 979, 980, 985, 986, 991, 992, 993, 994, 995, 999, 1000, 1001, 1002, 1003, 1006,
+ 1007, 1008, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1033,
+ 1034, 1043, 1044, 1045, 1046, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1068, 1075, 1085, 1086, 1087,
+ 1088, 1092, 1093, 1096, 1101, 1108, 1113, 1114, 1115, 1116, 1117, 1125, 1126, 1127, 1128, 1129, 1132,
+ 1134, 1135, 1142, 1143, 1144, 1146, 1147, 1150, 1151, 1152, 1153, 1154, 1155, 1161, 1162, 1163, 1164,
+ 1168, 1169, 1170, 1175, 1176, 1181, 1182, 1183, 1184, 1189, 1190, 1193, 1194, 1195, 1200, 1201, 1202,
+ 1216, 1217, 1218, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1232, 1233, 1241, 1242, 1243, 1244, 1283,
+ 1284, 1287, 1289, 1292, 1293, 1294, 1298, 1299, 1308, 1309, 1320, 1321, 1322, 1323, 1324, 1325, 1326,
+ 1329, 1331, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1346, 1347, 1348, 1349, 1350,
+ 1351, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1361, 1362, 1363, 1364, 1365, 1366, 1367, 1368, 1369,
+ 1370, 1371, 1372, 1373, 1374, 1375, 1376, 1377, 1378, 1383, 1384, 1385, 1386, 1387, 1388, 1389, 1390,
+ 1391, 1396, 1397, 1398, 1399, 1400, 1401, 1402, 1403, 1404, 1405, 1410, 1411, 1412, 1413, 1414, 1415,
+ 1416, 1417, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1430, 1431, 1432,
+ 1433, 1434, 1435, 1436, 1437, 1438, 1439, 1440, 1441, 1442, 1443, 1444, 1445, 1446, 1447, 1448, 1449,
+ 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458, 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466,
+ 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479, 1480, 1481, 1482, 1483,
+ 1484, 1485, 1486, 1487, 1489, 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1500, 1501,
+ 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518,
+ 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535,
+ 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552,
+ 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569,
+ 1570, 1571, 1572, 1573, 1574, 1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586,
+ 1587, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1600, 1601, 1602, 1603,
+ 1604, 1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1617, 1618, 1623, 1624, 1625, 1626, 1638, 1639,
+ 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656,
+ 1657, 1658, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1672, 1673,
+ 1674, 1675, 1676, 1677, 1678, 1679, 1680, 1681, 1682, 1683, 1684, 1685, 1686, 1687, 1688, 1689, 1690,
+ 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1701, 1702, 1703, 1704, 1705, 1706, 1707,
+ 1708, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1718, 1719, 1720, 1721, 1722, 1723, 1724,
+ 1725, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1735, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743,
+ 1744, 1745, 1746, 1747, 1748, 1749, 1750, 1751, 1756, 1757, 1758, 1759, 1763, 1764, 1765, 1766, 1767,
+ 1768, 1769, 1770, 1771, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1787, 1788, 1789,
+ 1790, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806,
+ 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1823, 1824,
+ 1825, 1826, 1827, 1830, 1831, 1832, 1835, 1836, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1854, 1860,
+ 1861, 1862, 1863, 1864, 1865, 1866, 1867, 1868, 1869, 1871, 1872, 1873, 1874, 1875, 1876, 1877, 1878,
+ 1879, 1880, 1881, 1886, 1887, 1888, 1889, 1890, 1891, 1892, 1893, 1894, 1895, 1896, 1897, 1898, 1899,
+ 1900, 1901, 1902, 1903, 1904, 1905, 1906, 1907, 1908, 1909, 1910, 1911, 1914, 1915, 1917, 1918, 1919,
+ 1920, 1921, 1922, 1923, 1924, 1925, 1926, 1927, 1928, 1938, 1939, 1942, 1943, 1944, 1945, 1946, 1947,
+ 1948, 1949, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957, 1958, 1959, 1964, 1965, 1966, 1967, 1968,
+ 1969, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1986, 1987, 1988, 1989,
+ 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2004, 2009, 2010, 2011, 2012, 2021, 2022,
+ 2023, 2024, 2025, 2026, 2029, 2030, 2033, 2034, 2035, 2036, 2037, 2038, 2039, 2040, 2041, 2042, 2043,
+ 2044, 2045, 2046, 2047, 2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055, 2056, 2057, 2058, 2059, 2060,
+ 2061, 2062, 2063, 2064, 2065, 2066, 2067, 2068, 2069, 2070, 2071, 2072, 2073, 2074, 2075, 2076, 2077,
+ 2078, 2079, 2080, 2081, 2082, 2083, 2092, 2093, 2094, 2095, 2096, 2097, 2098, 2099, 2100, 2101, 2102,
+ 2103, 2104, 2105, 2106, 2107, 2108, 2109, 2110, 2111, 2112, 2113, 2114, 2115, 2116, 2117, 2118, 2119,
+ 2120, 2121, 2122, 2125, 2126, 2127, 2134, 2135, 2136, 2137, 2138, 2139, 2140, 2141, 2142, 2143, 2148,
+ 2151, 2152, 2153, 2154, 2155, 2156, 2157, 2158, 2159, 2160, 2161, 2162, 2163, 2164, 2169, 2170, 2171,
+ 2172, 2173, 2174, 2175, 3186, 3187, 3188, 3189, 3190, 3191, 3192, 3193, 3194, 3195, 3196, 3197, 3198,
+ 3199, 3200, 3201, 3202, 3203, 3204, 3205, 3206, 3207, 3208, 3209, 3210, 3211, 3212, 3213, 3214, 3215,
+ 3216, 3217, 3218, 3219, 3220, 3221, 3222, 3223, 3224, 3225, 3226, 3227, 3228, 3229, 3230, 3231, 3232,
+ 3233, 3234, 3235, 3236, 3237, 3238, 3239, 3240, 3241, 3242, 3243, 3244, 3245, 3246, 3247, 3248, 3249,
+ 3250, 3251, 3252, 3253, 3254, 3255, 3256, 3257, 3258, 3259, 3260, 3261, 3262, 3263, 3264, 3265, 3266,
+ 3267, 3268, 3269, 3270, 3271, 3272, 3273, 3274, 3275, 3276, 3277, 3278, 3279, 3280, 3281, 3282, 3283,
+ 3284, 3285, 3286, 3287, 3288, 3289, 3290, 3399, 3400, 3401, 3404, 3414, 3442, 3457, 3459, 3461, 3463,
+ 3487, 3494, 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3502, 3503, 3504, 3505, 3506, 3507, 3508, 3509,
+ 3510, 3511, 3512, 3513, 3514, 3515, 3516, 3517, 3518, 3519, 3520, 3521, 3522, 3523, 3524, 3525, 3526,
+ 3527, 3528, 3529, 3530, 3531, 3532, 3533, 3534, 3535, 3536, 3537, 3538, 3539, 3540, 3541, 3542, 3543,
+ 3544, 3545, 3546, 3547, 3548, 3549, 3550, 3551, 3552, 3553, 3554, 3555, 3556, 3557, 3558, 3559, 3560,
+ 3561, 3562, 3563, 3564, 3565, 3566, 3567, 3568, 3569, 3570, 3571, 3572, 3573, 3574, 3575, 3576, 3577,
+ 3578, 3579, 3580, 3581, 3582, 3583, 3584, 3587, 3588, 3593, 3594, 3595, 3596, 3598, 3599, 3600, 3601,
+ 3604, 3605, 3611, 3614, 3623, 3624, 3625, 3626, 3628, 3629, 3630, 3634, 3635, 3636, 3637, 3643, 3644,
+ 3646, 3649, 3650, 3652, 3653, 3654, 3655, 3656, 3658, 3659, 3660, 3662, 3663, 3664, 3665, 3666, 3667,
+ 3668, 3670, 3671, 3672, 3673, 3676, 3677, 3678, 3679, 3680, 3681, 3685, 3691, 3693, 3695, 3697, 3698,
+ 3701, 3703, 3704, 3707, 3709, 3713, 3714, 3715, 3716, 3717, 3722, 3724, 3725, 3726, 3727, 3728, 3730,
+ 3734, 3737, 3738, 3739, 3740, 3742, 3745, 3752, 3753, 3754, 3756, 3757, 3760, 3761, 3762, 3769, 3771,
+ 3772, 3785, 3786, 3790, 3801, 3807, 3808, 3809, 3810, 3811, 3812, 3813, 3814, 3815, 3816, 3817, 3818,
+ 3819, 3820, 3821, 3822, 3823, 3824, 3825, 3826, 3827, 3828, 3829, 3830, 3831, 3832, 3833, 3834, 3835,
+ 3836, 3837, 3838, 3839, 3840, 3841, 3842, 3843, 3844, 3845, 3846, 3847, 3848, 3849, 3850, 3851, 3852,
+ 3853, 3854, 3855, 3856, 3857, 3858, 3859, 3860, 3861, 3862, 3863, 3864, 3865, 3866, 3867, 3868, 3869,
+ 3870, 3871, 3872, 3873, 3874, 3875, 3876, 3877, 3878, 3879, 3880, 3881, 3882, 3883, 3884, 3885, 3886,
+ 3887, 3888, 3889, 3890, 3891, 3892, 3893, 3894, 3895, 3896, 3897, 3898, 3899, 3900, 3901, 3902, 3903,
+ 3904, 3905, 3906, 3907, 3908, 3909, 3910, 3911, 3912, 3913, 3914, 3915, 3916, 3917, 3918, 3919, 3920,
+ 3921, 3922, 3923, 3924, 3925, 3926, 3927, 3928, 3929, 3931, 3932, 3933, 3934, 3935, 3936, 3937, 3938,
+ 3939, 3940, 3941, 3942, 3943, 3944, 3945, 3946, 3947, 3948, 3949, 3950, 3951, 3952, 3953, 3954, 3955,
+ 3956, 3957, 3958, 3959, 3960, 3961, 3962, 3963, 3964, 3965, 3966, 3967, 3968, 3969, 3970, 3971, 3972,
+ 3973, 3974, 3975, 3976, 3977, 3978, 3979, 3980, 3981, 3982, 3983, 3984, 3985, 3986, 3987, 3988, 3989,
+ 3990, 3991, 3992, 3993, 3994, 3995, 3996, 3997, 3998, 3999, 4000, 4001, 4002, 4003, 4004, 4005, 4006,
+ 4007, 4008, 4009, 4010, 4011, 4012, 4013, 4014, 4015, 4016, 4017, 4018, 4019, 4020, 4021, 4022, 4023,
+ 4024, 4025, 4026, 4027, 4028, 4029, 4030, 4031, 4032, 4033, 4034, 4035, 4036, 4037, 4038, 4039, 4040,
+ 4041, 4042, 4043, 4044, 4045, 4046, 4047, 4048, 4049, 4050, 4051, 4052, 4053, 4054, 4055, 4056, 4057,
+ 4058, 4059, 4060, 4061, 4062, 4063, 4064, 4065, 4066, 4067, 4068, 4069, 4070, 4071, 4072, 4073, 4074,
+ 4075, 4076, 4077, 4078, 4079, 4080, 4081, 4082, 4083, 4084, 4085, 4086, 4087, 4088, 4089, 4090, 4091,
+ 4092, 4093, 4094, 4095, 4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103, 4104, 4105, 4106, 4107, 4108,
+ 4109, 4110, 4111, 4112, 4113, 4114, 4115, 4116, 4117, 4118, 4119, 4120, 4121, 4122, 4123, 4124, 4125,
+ 4126, 4127, 4128, 4129, 4130, 4131, 4132, 4133, 4134, 4135, 4136, 4137, 4138, 4139, 4140, 4141, 4142,
+ 4143, 4144, 4145, 4146, 4147, 4148, 4149, 4150, 4151, 4152, 4153, 4154, 4155, 4156, 4157, 4158, 4159,
+ 4160, 4161, 4162, 4163, 4164, 4165, 4166, 4167, 4168, 4169, 4170, 4171, 4172, 4173, 4174, 4175, 4176,
+ 4177, 4178, 4179, 4180, 4181, 4182, 4183, 4184, 4185, 4186, 4187, 4188, 4189, 4190, 4191, 4192, 4193,
+ 4194, 4195, 4196, 4197, 4198, 4199, 4200, 4201, 4202, 4203, 4204, 4205, 4206, 4207, 4208, 4209, 4210,
+ 4211, 4212, 4213, 4214, 4215, 4216, 4217, 4218, 4219, 4220, 4221, 4222, 4223, 4224, 4225, 4226, 4227,
+ 4228, 4229, 4230, 4231, 4232, 4233, 4234, 4235, 4236, 4237, 4238, 4239, 4240, 4241, 4242, 4243, 4244,
+ 4245, 4246, 4247, 4248, 4249, 4250, 4251, 4252, 4253, 4254, 4255, 4256, 4257, 4258, 4259, 4260, 4261,
+ 4262, 4263, 4264, 4265, 4266, 4267, 4268, 4269, 4270, 4271, 4272, 4273, 4274, 4275, 4276, 4277, 4278,
+ 4279, 4280, 4281, 4282, 4283, 4284, 4285, 4286, 4287, 4288, 4289, 4290, 4291, 4292, 4293, 4294, 4295,
+ 4296, 4297, 4298, 4299, 4300, 4301, 4302, 4303, 4304, 4305, 4306, 4307, 4308, 4309, 4310, 4311, 4312,
+ 4313, 4314, 4315, 4316, 4317, 4318, 4319, 4320, 4321, 4322, 4323, 4324, 4325, 4326, 4327, 4328, 4329,
+ 4330, 4331, 4332, 4333, 4334, 4335, 4336, 4337, 4338, 4339, 4340, 4341, 4342, 4343, 4344, 4345, 4346,
+ 4347, 4348, 4349, 4350, 4351, 4352, 4353, 4354, 4355, 4356, 4357, 4358, 4359, 4360, 4361, 4362, 4363,
+ 4364, 4365, 4366, 4367, 4368, 4369, 4370, 4371, 4372, 4373, 4374, 4375, 4376, 4377, 4378, 4379, 4380,
+ 4381, 4382, 4383, 4384, 4385, 4386, 4387, 4388, 4389, 4390, 4391, 4392, 4393, 4394, 4395, 4396, 4397,
+ 4398, 4399, 4400, 4401, 4402, 4403, 4404, 4405, 4406, 4407, 4408, 4409, 4410, 4411, 4412, 4413, 4414,
+ 4415, 4416, 4417, 4418, 4419, 4420, 4421, 4422, 4423, 4424, 4425, 4426, 4427, 4428, 4429, 4430, 4431,
+ 4432, 4433, 4434, 4435, 4436, 4437, 4438, 4439, 4440, 4441, 4442, 4443, 4444, 4445, 4446, 4447, 4448,
+ 4449, 4450, 4451, 4452, 4453, 4454, 4455, 4456, 4457, 4458, 4459, 4460, 4461, 4462, 4463, 4464, 4465,
+ 4466, 4467, 4468, 4469, 4470, 4471, 4472, 4473, 4474, 4475, 4476,
+ ])
+ )
+
+ self.v.register_buffer(
+ "right_half",
+ torch.tensor([
+ 19, 20, 21, 22, 23, 24, 25, 26, 109, 110, 111, 112, 219, 220, 221, 222, 335, 336, 337, 338, 522, 523,
+ 524, 525, 526, 527, 528, 529, 534, 535, 536, 537, 554, 555, 556, 557, 584, 585, 586, 587, 595, 596, 597,
+ 598, 599, 600, 601, 602, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621,
+ 634, 635, 636, 637, 640, 641, 642, 643, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663,
+ 664, 665, 666, 675, 676, 677, 678, 684, 685, 686, 687, 689, 690, 698, 699, 700, 701, 710, 711, 716, 717,
+ 718, 719, 720, 721, 722, 741, 742, 743, 744, 749, 750, 751, 752, 776, 777, 778, 779, 780, 781, 782, 787,
+ 788, 789, 790, 791, 792, 793, 794, 800, 801, 810, 811, 812, 813, 817, 818, 819, 820, 830, 831, 832, 833,
+ 834, 835, 836, 839, 843, 844, 845, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862,
+ 863, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 886, 887, 888, 889, 890, 891, 892, 893, 894,
+ 895, 900, 901, 910, 911, 912, 913, 914, 915, 916, 917, 920, 921, 925, 930, 931, 932, 933, 934, 935, 936,
+ 937, 938, 940, 941, 956, 957, 973, 974, 975, 976, 981, 982, 983, 984, 987, 988, 989, 990, 996, 997, 998,
+ 1004, 1005, 1009, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1035, 1036, 1037, 1038, 1039,
+ 1040, 1041, 1042, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1066, 1067,
+ 1069, 1070, 1071, 1072, 1073, 1074, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1089, 1090,
+ 1091, 1094, 1095, 1097, 1098, 1099, 1100, 1102, 1103, 1104, 1105, 1106, 1107, 1109, 1110, 1111, 1112,
+ 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1130, 1131, 1133, 1136, 1137, 1138, 1139, 1140, 1141, 1145,
+ 1148, 1149, 1156, 1157, 1158, 1159, 1160, 1165, 1166, 1167, 1171, 1172, 1173, 1174, 1177, 1178, 1179,
+ 1180, 1185, 1186, 1187, 1188, 1191, 1192, 1196, 1197, 1198, 1199, 1203, 1204, 1205, 1206, 1207, 1208,
+ 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1219, 1220, 1221, 1222, 1223, 1231, 1234, 1235, 1236, 1237,
+ 1238, 1239, 1240, 1245, 1246, 1247, 1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258,
+ 1259, 1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275,
+ 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1285, 1286, 1288, 1290, 1291, 1295, 1296, 1297, 1300, 1301,
+ 1302, 1303, 1304, 1305, 1306, 1307, 1310, 1311, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1327,
+ 1328, 1330, 1332, 1333, 1334, 1335, 1359, 1360, 1379, 1380, 1381, 1382, 1392, 1393, 1394, 1395, 1406,
+ 1407, 1408, 1409, 1488, 1613, 1614, 1615, 1616, 1619, 1620, 1621, 1622, 1627, 1628, 1629, 1630, 1631,
+ 1632, 1633, 1634, 1635, 1636, 1637, 1726, 1727, 1752, 1753, 1754, 1755, 1760, 1761, 1762, 1772, 1783,
+ 1784, 1785, 1786, 1822, 1828, 1829, 1833, 1834, 1837, 1838, 1839, 1840, 1841, 1842, 1843, 1844, 1845,
+ 1853, 1855, 1856, 1857, 1858, 1859, 1870, 1882, 1883, 1884, 1885, 1912, 1913, 1916, 1929, 1930, 1931,
+ 1932, 1933, 1934, 1935, 1936, 1937, 1940, 1941, 1960, 1961, 1962, 1963, 1982, 1983, 1984, 1985, 2000,
+ 2001, 2002, 2003, 2005, 2006, 2007, 2008, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2027, 2028,
+ 2031, 2032, 2036, 2084, 2085, 2086, 2087, 2088, 2089, 2090, 2091, 2123, 2124, 2128, 2129, 2130, 2131,
+ 2132, 2133, 2144, 2145, 2146, 2147, 2149, 2150, 2151, 2165, 2166, 2167, 2168, 2176, 2177, 2178, 2179,
+ 2180, 2181, 2182, 2183, 2184, 2185, 2186, 2187, 2188, 2189, 2190, 2191, 2192, 2193, 2194, 2195, 2196,
+ 2197, 2198, 2199, 2200, 2201, 2202, 2203, 2204, 2205, 2206, 2207, 2208, 2209, 2210, 2211, 2212, 2213,
+ 2214, 2215, 2216, 2217, 2218, 2219, 2220, 2221, 2222, 2223, 2224, 2225, 2226, 2227, 2228, 2229, 2230,
+ 2231, 2232, 2233, 2234, 2235, 2236, 2237, 2238, 2239, 2240, 2241, 2242, 2243, 2244, 2245, 2246, 2247,
+ 2248, 2249, 2250, 2251, 2252, 2253, 2254, 2255, 2256, 2257, 2258, 2259, 2260, 2261, 2262, 2263, 2264,
+ 2265, 2266, 2267, 2268, 2269, 2270, 2271, 2272, 2273, 2274, 2275, 2276, 2277, 2278, 2279, 2280, 2281,
+ 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2290, 2291, 2292, 2293, 2294, 2295, 2296, 2297, 2298,
+ 2299, 2300, 2301, 2302, 2303, 2304, 2305, 2306, 2307, 2308, 2309, 2310, 2311, 2312, 2313, 2314, 2315,
+ 2316, 2317, 2318, 2319, 2320, 2321, 2322, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330, 2331, 2332,
+ 2333, 2334, 2335, 2336, 2337, 2338, 2339, 2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347, 2348, 2349,
+ 2350, 2351, 2352, 2353, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2362, 2363, 2364, 2365, 2366,
+ 2367, 2368, 2369, 2370, 2371, 2372, 2373, 2374, 2375, 2376, 2377, 2378, 2379, 2380, 2381, 2382, 2383,
+ 2384, 2385, 2386, 2387, 2388, 2389, 2390, 2391, 2392, 2393, 2394, 2395, 2396, 2397, 2398, 2399, 2400,
+ 2401, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2409, 2410, 2411, 2412, 2413, 2414, 2415, 2416, 2417,
+ 2418, 2419, 2420, 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2429, 2430, 2431, 2432, 2433, 2434,
+ 2435, 2436, 2437, 2438, 2439, 2440, 2441, 2442, 2443, 2444, 2445, 2446, 2447, 2448, 2449, 2450, 2451,
+ 2452, 2453, 2454, 2455, 2456, 2457, 2458, 2459, 2460, 2461, 2462, 2463, 2464, 2465, 2466, 2467, 2468,
+ 2469, 2470, 2471, 2472, 2473, 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2482, 2483, 2484, 2485,
+ 2486, 2487, 2488, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2497, 2498, 2499, 2500, 2501, 2502,
+ 2503, 2504, 2505, 2506, 2507, 2508, 2509, 2510, 2511, 2512, 2513, 2514, 2515, 2516, 2517, 2518, 2519,
+ 2520, 2521, 2522, 2523, 2524, 2525, 2526, 2527, 2528, 2529, 2530, 2531, 2532, 2533, 2534, 2535, 2536,
+ 2537, 2538, 2539, 2540, 2541, 2542, 2543, 2544, 2545, 2546, 2547, 2548, 2549, 2550, 2551, 2552, 2553,
+ 2554, 2555, 2556, 2557, 2558, 2559, 2560, 2561, 2562, 2563, 2564, 2565, 2566, 2567, 2568, 2569, 2570,
+ 2571, 2572, 2573, 2574, 2575, 2576, 2577, 2578, 2579, 2580, 2581, 2582, 2583, 2584, 2585, 2586, 2587,
+ 2588, 2589, 2590, 2591, 2592, 2593, 2594, 2595, 2596, 2597, 2598, 2599, 2600, 2601, 2602, 2603, 2604,
+ 2605, 2606, 2607, 2608, 2609, 2610, 2611, 2612, 2613, 2614, 2615, 2616, 2617, 2618, 2619, 2620, 2621,
+ 2622, 2623, 2624, 2625, 2626, 2627, 2628, 2629, 2630, 2631, 2632, 2633, 2634, 2635, 2636, 2637, 2638,
+ 2639, 2640, 2641, 2642, 2643, 2644, 2645, 2646, 2647, 2648, 2649, 2650, 2651, 2652, 2653, 2654, 2655,
+ 2656, 2657, 2658, 2659, 2660, 2661, 2662, 2663, 2664, 2665, 2666, 2667, 2668, 2669, 2670, 2671, 2672,
+ 2673, 2674, 2675, 2676, 2677, 2678, 2679, 2680, 2681, 2682, 2683, 2684, 2685, 2686, 2687, 2688, 2689,
+ 2690, 2691, 2692, 2693, 2694, 2695, 2696, 2697, 2698, 2699, 2700, 2701, 2702, 2703, 2704, 2705, 2706,
+ 2707, 2708, 2709, 2710, 2711, 2712, 2713, 2714, 2715, 2716, 2717, 2718, 2719, 2720, 2721, 2722, 2723,
+ 2724, 2725, 2726, 2727, 2728, 2729, 2730, 2731, 2732, 2733, 2734, 2735, 2736, 2737, 2738, 2739, 2740,
+ 2741, 2742, 2743, 2744, 2745, 2746, 2747, 2748, 2749, 2750, 2751, 2752, 2753, 2754, 2755, 2756, 2757,
+ 2758, 2759, 2760, 2761, 2762, 2763, 2764, 2765, 2766, 2767, 2768, 2769, 2770, 2771, 2772, 2773, 2774,
+ 2775, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2784, 2785, 2786, 2787, 2788, 2789, 2790, 2791,
+ 2792, 2793, 2794, 2795, 2796, 2797, 2798, 2799, 2800, 2801, 2802, 2803, 2804, 2805, 2806, 2807, 2808,
+ 2809, 2810, 2811, 2812, 2813, 2814, 2815, 2816, 2817, 2818, 2819, 2820, 2821, 2822, 2823, 2824, 2825,
+ 2826, 2827, 2828, 2829, 2830, 2831, 2832, 2833, 2834, 2835, 2836, 2837, 2838, 2839, 2840, 2841, 2842,
+ 2843, 2844, 2845, 2846, 2847, 2848, 2849, 2850, 2851, 2852, 2853, 2854, 2855, 2856, 2857, 2858, 2859,
+ 2860, 2861, 2862, 2863, 2864, 2865, 2866, 2867, 2868, 2869, 2870, 2871, 2872, 2873, 2874, 2875, 2876,
+ 2877, 2878, 2879, 2880, 2881, 2882, 2883, 2884, 2885, 2886, 2887, 2888, 2889, 2890, 2891, 2892, 2893,
+ 2894, 2895, 2896, 2897, 2898, 2899, 2900, 2901, 2902, 2903, 2904, 2905, 2906, 2907, 2908, 2909, 2910,
+ 2911, 2912, 2913, 2914, 2915, 2916, 2917, 2918, 2919, 2920, 2921, 2922, 2923, 2924, 2925, 2926, 2927,
+ 2928, 2929, 2930, 2931, 2932, 2933, 2934, 2935, 2936, 2937, 2938, 2939, 2940, 2941, 2942, 2943, 2944,
+ 2945, 2946, 2947, 2948, 2949, 2950, 2951, 2952, 2953, 2954, 2955, 2956, 2957, 2958, 2959, 2960, 2961,
+ 2962, 2963, 2964, 2965, 2966, 2967, 2968, 2969, 2970, 2971, 2972, 2973, 2974, 2975, 2976, 2977, 2978,
+ 2979, 2980, 2981, 2982, 2983, 2984, 2985, 2986, 2987, 2988, 2989, 2990, 2991, 2992, 2993, 2994, 2995,
+ 2996, 2997, 2998, 2999, 3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009, 3010, 3011, 3012,
+ 3013, 3014, 3015, 3016, 3017, 3018, 3019, 3020, 3021, 3022, 3023, 3024, 3025, 3026, 3027, 3028, 3029,
+ 3030, 3031, 3032, 3033, 3034, 3035, 3036, 3037, 3038, 3039, 3040, 3041, 3042, 3043, 3044, 3045, 3046,
+ 3047, 3048, 3049, 3050, 3051, 3052, 3053, 3054, 3055, 3056, 3057, 3058, 3059, 3060, 3061, 3062, 3063,
+ 3064, 3065, 3066, 3067, 3068, 3069, 3070, 3071, 3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079, 3080,
+ 3081, 3082, 3083, 3084, 3085, 3086, 3087, 3088, 3089, 3090, 3091, 3092, 3093, 3094, 3095, 3096, 3097,
+ 3098, 3099, 3100, 3101, 3102, 3103, 3104, 3105, 3106, 3107, 3108, 3109, 3110, 3111, 3112, 3113, 3114,
+ 3115, 3116, 3117, 3118, 3119, 3120, 3121, 3122, 3123, 3124, 3125, 3126, 3127, 3128, 3129, 3130, 3131,
+ 3132, 3133, 3134, 3135, 3136, 3137, 3138, 3139, 3140, 3141, 3142, 3143, 3144, 3145, 3146, 3147, 3148,
+ 3149, 3150, 3151, 3152, 3153, 3154, 3155, 3156, 3157, 3158, 3159, 3160, 3161, 3162, 3163, 3164, 3165,
+ 3166, 3167, 3168, 3169, 3170, 3171, 3172, 3173, 3174, 3175, 3176, 3177, 3178, 3179, 3180, 3181, 3182,
+ 3183, 3184, 3185, 3222, 3223, 3248, 3249, 3275, 3276, 3277, 3278, 3281, 3282, 3283, 3284, 3285, 3290,
+ 3291, 3292, 3293, 3294, 3295, 3296, 3297, 3298, 3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307,
+ 3308, 3309, 3310, 3311, 3312, 3313, 3314, 3315, 3316, 3317, 3318, 3319, 3320, 3321, 3322, 3323, 3324,
+ 3325, 3326, 3327, 3328, 3329, 3330, 3331, 3332, 3333, 3334, 3335, 3336, 3337, 3338, 3339, 3340, 3341,
+ 3342, 3343, 3344, 3345, 3346, 3347, 3348, 3349, 3350, 3351, 3352, 3353, 3354, 3355, 3356, 3357, 3358,
+ 3359, 3360, 3361, 3362, 3363, 3364, 3365, 3366, 3367, 3368, 3369, 3370, 3371, 3372, 3373, 3374, 3375,
+ 3376, 3377, 3378, 3379, 3380, 3381, 3382, 3383, 3384, 3385, 3386, 3387, 3388, 3389, 3390, 3391, 3392,
+ 3393, 3394, 3395, 3396, 3397, 3398, 3399, 3400, 3401, 3402, 3403, 3404, 3405, 3406, 3407, 3408, 3409,
+ 3410, 3411, 3412, 3413, 3414, 3415, 3416, 3417, 3418, 3419, 3420, 3421, 3422, 3423, 3424, 3425, 3426,
+ 3427, 3428, 3429, 3430, 3431, 3432, 3433, 3434, 3435, 3436, 3437, 3438, 3439, 3440, 3441, 3442, 3443,
+ 3444, 3445, 3446, 3447, 3448, 3449, 3450, 3451, 3452, 3453, 3454, 3455, 3456, 3457, 3458, 3459, 3460,
+ 3461, 3462, 3463, 3464, 3465, 3466, 3467, 3468, 3469, 3470, 3471, 3472, 3473, 3474, 3475, 3476, 3477,
+ 3478, 3479, 3480, 3481, 3482, 3483, 3484, 3485, 3486, 3487, 3488, 3489, 3490, 3491, 3492, 3493, 3494,
+ 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3502, 3503, 3504, 3505, 3506, 3507, 3508, 3509, 3510, 3511,
+ 3512, 3513, 3514, 3515, 3516, 3517, 3518, 3519, 3520, 3521, 3522, 3523, 3524, 3525, 3526, 3527, 3528,
+ 3529, 3530, 3531, 3532, 3533, 3534, 3535, 3536, 3537, 3538, 3539, 3540, 3541, 3542, 3543, 3544, 3545,
+ 3546, 3547, 3548, 3549, 3550, 3551, 3552, 3553, 3554, 3555, 3556, 3557, 3558, 3559, 3560, 3561, 3562,
+ 3563, 3564, 3565, 3566, 3567, 3568, 3569, 3570, 3571, 3572, 3573, 3574, 3575, 3585, 3586, 3589, 3590,
+ 3591, 3592, 3597, 3602, 3603, 3606, 3607, 3608, 3609, 3610, 3612, 3613, 3615, 3616, 3617, 3618, 3619,
+ 3620, 3621, 3622, 3627, 3631, 3632, 3633, 3638, 3639, 3640, 3641, 3642, 3645, 3647, 3648, 3651, 3657,
+ 3661, 3668, 3669, 3674, 3675, 3682, 3683, 3684, 3686, 3687, 3688, 3689, 3690, 3692, 3694, 3696, 3699,
+ 3700, 3702, 3704, 3705, 3706, 3708, 3710, 3711, 3712, 3718, 3719, 3720, 3721, 3723, 3729, 3731, 3732,
+ 3733, 3735, 3736, 3741, 3743, 3744, 3746, 3747, 3748, 3749, 3750, 3751, 3755, 3758, 3759, 3763, 3764,
+ 3765, 3766, 3767, 3768, 3770, 3773, 3774, 3775, 3776, 3777, 3778, 3779, 3780, 3781, 3782, 3783, 3784,
+ 3785, 3786, 3787, 3788, 3789, 3790, 3791, 3792, 3793, 3794, 3795, 3796, 3797, 3798, 3799, 3800, 3801,
+ 3802, 3803, 3804, 3805, 3806, 3930, 4477, 4478, 4479, 4480, 4481, 4482, 4483, 4484, 4485, 4486, 4487,
+ 4488, 4489, 4490, 4491, 4492, 4493, 4494, 4495, 4496, 4497, 4498, 4499, 4500, 4501, 4502, 4503, 4504,
+ 4505, 4506, 4507, 4508, 4509, 4510, 4511, 4512, 4513, 4514, 4515, 4516, 4517, 4518, 4519, 4520, 4521,
+ 4522, 4523, 4524, 4525, 4526, 4527, 4528, 4529, 4530, 4531, 4532, 4533, 4534, 4535, 4536, 4537, 4538,
+ 4539, 4540, 4541, 4542, 4543, 4544, 4545, 4546, 4547, 4548, 4549, 4550, 4551, 4552, 4553, 4554, 4555,
+ 4556, 4557, 4558, 4559, 4560, 4561, 4562, 4563, 4564, 4565, 4566, 4567, 4568, 4569, 4570, 4571, 4572,
+ 4573, 4574, 4575, 4576, 4577, 4578, 4579, 4580, 4581, 4582, 4583, 4584, 4585, 4586, 4587, 4588, 4589,
+ 4590, 4591, 4592, 4593, 4594, 4595, 4596, 4597, 4598, 4599, 4600, 4601, 4602, 4603, 4604, 4605, 4606,
+ 4607, 4608, 4609, 4610, 4611, 4612, 4613, 4614, 4615, 4616, 4617, 4618, 4619, 4620, 4621, 4622, 4623,
+ 4624, 4625, 4626, 4627, 4628, 4629, 4630, 4631, 4632, 4633, 4634, 4635, 4636, 4637, 4638, 4639, 4640,
+ 4641, 4642, 4643, 4644, 4645, 4646, 4647, 4648, 4649, 4650, 4651, 4652, 4653, 4654, 4655, 4656, 4657,
+ 4658, 4659, 4660, 4661, 4662, 4663, 4664, 4665, 4666, 4667, 4668, 4669, 4670, 4671, 4672, 4673, 4674,
+ 4675, 4676, 4677, 4678, 4679, 4680, 4681, 4682, 4683, 4684, 4685, 4686, 4687, 4688, 4689, 4690, 4691,
+ 4692, 4693, 4694, 4695, 4696, 4697, 4698, 4699, 4700, 4701, 4702, 4703, 4704, 4705, 4706, 4707, 4708,
+ 4709, 4710, 4711, 4712, 4713, 4714, 4715, 4716, 4717, 4718, 4719, 4720, 4721, 4722, 4723, 4724, 4725,
+ 4726, 4727, 4728, 4729, 4730, 4731, 4732, 4733, 4734, 4735, 4736, 4737, 4738, 4739, 4740, 4741, 4742,
+ 4743, 4744, 4745, 4746, 4747, 4748, 4749, 4750, 4751, 4752, 4753, 4754, 4755, 4756, 4757, 4758, 4759,
+ 4760, 4761, 4762, 4763, 4764, 4765, 4766, 4767, 4768, 4769, 4770, 4771, 4772, 4773, 4774, 4775, 4776,
+ 4777, 4778, 4779, 4780, 4781, 4782, 4783, 4784, 4785, 4786, 4787, 4788, 4789, 4790, 4791, 4792, 4793,
+ 4794, 4795, 4796, 4797, 4798, 4799, 4800, 4801, 4802, 4803, 4804, 4805, 4806, 4807, 4808, 4809, 4810,
+ 4811, 4812, 4813, 4814, 4815, 4816, 4817, 4818, 4819, 4820, 4821, 4822, 4823, 4824, 4825, 4826, 4827,
+ 4828, 4829, 4830, 4831, 4832, 4833, 4834, 4835, 4836, 4837, 4838, 4839, 4840, 4841, 4842, 4843, 4844,
+ 4845, 4846, 4847, 4848, 4849, 4850, 4851, 4852, 4853, 4854, 4855, 4856, 4857, 4858, 4859, 4860, 4861,
+ 4862, 4863, 4864, 4865, 4866, 4867, 4868, 4869, 4870, 4871, 4872, 4873, 4874, 4875, 4876, 4877, 4878,
+ 4879, 4880, 4881, 4882, 4883, 4884, 4885, 4886, 4887, 4888, 4889, 4890, 4891, 4892, 4893, 4894, 4895,
+ 4896, 4897, 4898, 4899, 4900, 4901, 4902, 4903, 4904, 4905, 4906, 4907, 4908, 4909, 4910, 4911, 4912,
+ 4913, 4914, 4915, 4916, 4917, 4918, 4919, 4920, 4921, 4922, 4923, 4924, 4925, 4926, 4927, 4928, 4929,
+ 4930, 4931, 4932, 4933, 4934, 4935, 4936, 4937, 4938, 4939, 4940, 4941, 4942, 4943, 4944, 4945, 4946,
+ 4947, 4948, 4949, 4950, 4951, 4952, 4953, 4954, 4955, 4956, 4957, 4958, 4959, 4960, 4961, 4962, 4963,
+ 4964, 4965, 4966, 4967, 4968, 4969, 4970, 4971, 4972, 4973, 4974, 4975, 4976, 4977, 4978, 4979, 4980,
+ 4981, 4982, 4983, 4984, 4985, 4986, 4987, 4988, 4989, 4990, 4991, 4992, 4993, 4994, 4995, 4996, 4997,
+ 4998, 4999, 5000, 5001, 5002, 5003, 5004, 5005, 5006, 5007, 5008, 5009, 5010, 5011, 5012, 5013, 5014,
+ 5015, 5016, 5017, 5018, 5019, 5020, 5021, 5022
+ ])
+ )
+
+ # remove the intersection with neck from scalp and get the region for hair
+ face_and_neck = torch.cat([self.v.face, self.v.neck]).unique()
+ # get the intersection between scalp and face_and_neck
+ uniques, counts = torch.cat([self.v.scalp, face_and_neck]).unique(return_counts=True)
+ intersection = uniques[counts == 2]
+ uniques, counts = torch.cat([self.v.scalp, intersection]).unique(return_counts=True)
+ hair = uniques[counts == 1]
+ self.v.register_buffer("hair", hair)
+
+ # unions
+ self.v.register_buffer("ears", torch.cat([self.v.right_ear, self.v.left_ear]))
+ self.v.register_buffer("eyeballs", torch.cat([self.v.right_eyeball, self.v.left_eyeball]))
+ self.v.register_buffer("irises", torch.cat([self.v.right_iris, self.v.left_iris]))
+ self.v.register_buffer("left_eye", torch.cat([self.v.left_eye_region, self.v.left_eyeball]))
+ self.v.register_buffer("right_eye", torch.cat([self.v.right_eye_region, self.v.right_eyeball]))
+ self.v.register_buffer("eyelids", torch.cat([self.v.left_eyelid, self.v.right_eyelid]))
+ self.v.register_buffer("lip_inside_ring", torch.cat(
+ [self.v.lip_inside_ring_upper, self.v.lip_inside_ring_lower, torch.tensor([1594, 2730])]))
+
+ # remove the intersection with irises from eyeballs and get the region for scleras
+ uniques, counts = torch.cat([self.v.eyeballs, self.v.irises]).unique(return_counts=True)
+ intersection = uniques[counts == 2]
+ uniques, counts = torch.cat([self.v.eyeballs, intersection]).unique(return_counts=True)
+ sclerae = uniques[counts == 1]
+ self.v.register_buffer("sclerae", sclerae)
+
+ # skin
+ skin_except = ["eyeballs", "hair", "lips_tight", "boundary"]
+ if self.num_verts == 5083:
+ skin_except.append("teeth")
+ skin = self.get_vid_except_region(skin_except)
+ self.v.register_buffer("skin", skin)
+
+ def construct_vid_table(self):
+ self.vid_to_region = defaultdict(list) # vertex id -> region name
+ for region_name, v_mask in self.v:
+ for v_id in v_mask:
+ self.vid_to_region[v_id.item()].append(region_name)
+
+ def process_face_mask(self, faces):
+
+ face_masks = defaultdict(list) # region name -> face id
+ for f_id, f in enumerate(faces):
+ counters = defaultdict(int)
+ for v_id in f:
+ for region_name in self.vid_to_region[v_id.item()]:
+ counters[region_name] += 1
+
+ for region_name, count in counters.items():
+ if count >= 3: # create straight boundaries, with seams
+ # if count > 1: # create zigzag boundaries, no seams
+ face_masks[region_name].append(f_id)
+
+ self.f = BufferContainer()
+ for region_name, f_mask in face_masks.items():
+ self.f.register_buffer(region_name, torch.tensor(f_mask, dtype=torch.long))
+
+ def process_face_clusters(self, face_clusters):
+ """ Construct a lookup table from face id to cluster id.
+
+ cluster #0: background
+ cluster #1: foreground
+ cluster #2: faces in face_clusters[0]
+ cluster #3: faces in face_clusters[1]
+ ...
+ """
+ fid2cid = torch.ones(self.num_faces + 1, dtype=torch.long) # faces are always treated as foreground
+ for cid, cluster in enumerate(face_clusters):
+ try:
+ fids = self.get_fid_by_region([cluster])
+ except Exception as e:
+ continue
+ fid2cid[
+ fids] = cid + 2 # reserve cluster #0 for the background and #1 for faces that do not belong to any cluster
+ self.register_buffer("fid2cid", fid2cid)
+
+ def process_vt_mask(self, faces, faces_t):
+ vt_masks = defaultdict(list) # region name -> vt id
+ for f_id, (face, face_t) in enumerate(zip(faces, faces_t)):
+ for v_id, vt_id in zip(face, face_t):
+ for region_name in self.vid_to_region[v_id.item()]:
+ vt_masks[region_name].append(vt_id.item())
+
+ self.vt = BufferContainer()
+ for region_name, vt_mask in vt_masks.items():
+ self.vt.register_buffer(region_name, torch.tensor(vt_mask, dtype=torch.long))
+
+ def get_vid_by_region(self, regions, keep_order=False):
+ """Get vertex indicies by regions"""
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ vid = torch.cat([self.v.get_buffer(k) for k in regions])
+ if keep_order:
+ return vid
+ else:
+ return vid.unique()
+ else:
+ return torch.tensor([], dtype=torch.long)
+
+ def get_vid_except_region(self, regions):
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ indices = torch.cat([self.v.get_buffer(k) for k in regions]).unique()
+ else:
+ indices = torch.tensor([], dtype=torch.long)
+
+ # get the vertex indicies that are not included by regions
+ vert_idx = torch.arange(0, self.num_verts, device=indices.device)
+ combined = torch.cat((indices, vert_idx))
+ uniques, counts = combined.unique(return_counts=True)
+ return uniques[counts == 1]
+
+ def get_fid_by_region(self, regions):
+ """Get face indicies by regions"""
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ return torch.cat([self.f.get_buffer(k) for k in regions]).unique()
+ else:
+ return torch.tensor([], dtype=torch.long)
+
+ def get_fid_except_region(self, regions):
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ indices = torch.cat([self.f.get_buffer(k) for k in regions]).unique()
+ else:
+ indices = torch.tensor([], dtype=torch.long)
+
+ # get the face indicies that are not included by regions
+ face_idx = torch.arange(0, self.num_faces, device=indices.device)
+ combined = torch.cat((indices, face_idx))
+ uniques, counts = combined.unique(return_counts=True)
+ return uniques[counts == 1]
+
+ def get_fid_except_fids(self, fids):
+ # get the face indicies that are not included
+ face_idx = torch.arange(0, self.num_faces, device=fids.device)
+ combined = torch.cat((fids, face_idx))
+ uniques, counts = combined.unique(return_counts=True)
+ return uniques[counts == 1]
+
+
+if __name__ == '__main__':
+ flame_model = FlameHead(shape_params=300, expr_params=100)
diff --git a/lam/models/rendering/flame_model/lbs.py b/lam/models/rendering/flame_model/lbs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c377f6484a67e75b36fb35ccf28200ab122d5af
--- /dev/null
+++ b/lam/models/rendering/flame_model/lbs.py
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+
+
+def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
+ """Calculates the rotation matrices for a batch of rotation vectors
+ Parameters
+ ----------
+ rot_vecs: torch.tensor Nx3
+ array of N axis-angle vectors
+ Returns
+ -------
+ R: torch.tensor Nx3x3
+ The rotation matrices for the given axis-angle parameters
+ """
+
+ batch_size = rot_vecs.shape[0]
+ device = rot_vecs.device
+
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
+ rot_dir = rot_vecs / angle
+
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
+
+ # Bx1 arrays
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
+
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
+ (batch_size, 3, 3)
+ )
+
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
+ return rot_mat
+
+
+def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
+ """Calculates landmarks by barycentric interpolation
+
+ Parameters
+ ----------
+ vertices: torch.tensor BxVx3, dtype = torch.float32
+ The tensor of input vertices
+ faces: torch.tensor Fx3, dtype = torch.long
+ The faces of the mesh
+ lmk_faces_idx: torch.tensor L, dtype = torch.long
+ The tensor with the indices of the faces used to calculate the
+ landmarks.
+ lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
+ The tensor of barycentric coordinates that are used to interpolate
+ the landmarks
+
+ Returns
+ -------
+ landmarks: torch.tensor BxLx3, dtype = torch.float32
+ The coordinates of the landmarks for each mesh in the batch
+ """
+ # Extract the indices of the vertices for each face
+ # BxLx3
+ batch_size, num_verts = vertices.shape[:2]
+ device = vertices.device
+
+ lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
+ batch_size, -1, 3
+ )
+
+ lmk_faces += (
+ torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1)
+ * num_verts
+ )
+
+ lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
+
+ landmarks = torch.einsum("blfi,blf->bli", [lmk_vertices, lmk_bary_coords])
+ return landmarks
+
+
+def lbs(
+ pose,
+ v_shaped,
+ posedirs,
+ J_regressor,
+ parents,
+ lbs_weights,
+ pose2rot=True,
+ dtype=torch.float32,
+):
+ """Performs Linear Blend Skinning with the given shape and pose parameters
+
+ Parameters
+ ----------
+ betas : torch.tensor BxNB
+ The tensor of shape parameters
+ pose : torch.tensor Bx(J + 1) * 3
+ The pose parameters in axis-angle format
+ v_template: torch.tensor BxVx3
+ The template mesh that will be deformed
+ shapedirs : torch.tensor 1xNB
+ The tensor of PCA shape displacements
+ posedirs : torch.tensor Px(V * 3)
+ The pose PCA coefficients
+ J_regressor : torch.tensor JxV
+ The regressor array that is used to calculate the joints from
+ the position of the vertices
+ parents: torch.tensor J
+ The array that describes the kinematic tree for the model
+ lbs_weights: torch.tensor N x V x (J + 1)
+ The linear blend skinning weights that represent how much the
+ rotation matrix of each part affects each vertex
+ pose2rot: bool, optional
+ Flag on whether to convert the input pose tensor to rotation
+ matrices. The default value is True. If False, then the pose tensor
+ should already contain rotation matrices and have a size of
+ Bx(J + 1)x9
+ dtype: torch.dtype, optional
+
+ Returns
+ -------
+ verts: torch.tensor BxVx3
+ The vertices of the mesh after applying the shape and pose
+ displacements.
+ joints: torch.tensor BxJx3
+ The joints of the model
+ """
+
+ batch_size = pose.shape[0]
+ device = pose.device
+
+ # Get the joints
+ # NxJx3 array
+ J = vertices2joints(J_regressor, v_shaped)
+
+ # 3. Add pose blend shapes
+ # N x J x 3 x 3
+ ident = torch.eye(3, dtype=dtype, device=device)
+ if pose2rot:
+ rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view(
+ [batch_size, -1, 3, 3]
+ )
+
+ pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
+ # (N x P) x (P, V * 3) -> N x V x 3
+ pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
+ else:
+ pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
+ rot_mats = pose.view(batch_size, -1, 3, 3)
+
+ pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(
+ batch_size, -1, 3
+ )
+
+ v_posed = pose_offsets + v_shaped
+
+ # 4. Get the global joint location
+ J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
+
+ # 5. Do skinning:
+ # W is N x V x (J + 1)
+ W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
+ # (N x V x (J + 1)) x (N x (J + 1) x 16)
+ num_joints = J_regressor.shape[0]
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
+
+ homogen_coord = torch.ones(
+ [batch_size, v_posed.shape[1], 1], dtype=dtype, device=device
+ )
+ v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
+ v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
+
+ verts = v_homo[:, :, :3, 0]
+
+ return verts, J_transformed, A[:, 1]
+
+
+def vertices2joints(J_regressor, vertices):
+ """Calculates the 3D joint locations from the vertices
+
+ Parameters
+ ----------
+ J_regressor : torch.tensor JxV
+ The regressor array that is used to calculate the joints from the
+ position of the vertices
+ vertices : torch.tensor BxVx3
+ The tensor of mesh vertices
+
+ Returns
+ -------
+ torch.tensor BxJx3
+ The location of the joints
+ """
+
+ return torch.einsum("bik,ji->bjk", [vertices, J_regressor])
+
+
+def blend_shapes(betas, shape_disps):
+ """Calculates the per vertex displacement due to the blend shapes
+
+
+ Parameters
+ ----------
+ betas : torch.tensor Bx(num_betas)
+ Blend shape coefficients
+ shape_disps: torch.tensor Vx3x(num_betas)
+ Blend shapes
+
+ Returns
+ -------
+ torch.tensor BxVx3
+ The per-vertex displacement due to shape deformation
+ """
+
+ # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
+ # i.e. Multiply each shape displacement by its corresponding beta and
+ # then sum them.
+ blend_shape = torch.einsum("bl,mkl->bmk", [betas, shape_disps])
+ return blend_shape
+
+
+def transform_mat(R, t):
+ """Creates a batch of transformation matrices
+ Args:
+ - R: Bx3x3 array of a batch of rotation matrices
+ - t: Bx3x1 array of a batch of translation vectors
+ Returns:
+ - T: Bx4x4 Transformation matrix
+ """
+ # No padding left or right, only add an extra row
+ return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
+
+
+def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
+ """
+ Applies a batch of rigid transformations to the joints
+
+ Parameters
+ ----------
+ rot_mats : torch.tensor BxNx3x3
+ Tensor of rotation matrices
+ joints : torch.tensor BxNx3
+ Locations of joints
+ parents : torch.tensor BxN
+ The kinematic tree of each object
+ dtype : torch.dtype, optional:
+ The data type of the created tensors, the default is torch.float32
+
+ Returns
+ -------
+ posed_joints : torch.tensor BxNx3
+ The locations of the joints after applying the pose rotations
+ rel_transforms : torch.tensor BxNx4x4
+ The relative (with respect to the root joint) rigid transformations
+ for all the joints
+ """
+
+ joints = torch.unsqueeze(joints, dim=-1)
+
+ rel_joints = joints.clone().contiguous()
+ rel_joints[:, 1:] = rel_joints[:, 1:] - joints[:, parents[1:]]
+
+ transforms_mat = transform_mat(rot_mats.view(-1, 3, 3), rel_joints.view(-1, 3, 1))
+ transforms_mat = transforms_mat.view(-1, joints.shape[1], 4, 4)
+
+ transform_chain = [transforms_mat[:, 0]]
+ for i in range(1, parents.shape[0]):
+ # Subtract the joint location at the rest pose
+ # No need for rotation, since it's identity when at rest
+ curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i])
+ transform_chain.append(curr_res)
+
+ transforms = torch.stack(transform_chain, dim=1)
+
+ # The last column of the transformations contains the posed joints
+ posed_joints = transforms[:, :, :3, 3]
+
+ joints_homogen = F.pad(joints, [0, 0, 0, 1])
+
+ rel_transforms = transforms - F.pad(
+ torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]
+ )
+
+ return posed_joints, rel_transforms
diff --git a/lam/models/rendering/gaussian_model.py b/lam/models/rendering/gaussian_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad7d749911abd4845f7b26de47fdbb87c02b63e2
--- /dev/null
+++ b/lam/models/rendering/gaussian_model.py
@@ -0,0 +1,177 @@
+import os
+from plyfile import PlyData, PlyElement
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import math
+import copy
+from lam.models.rendering.utils.typing import *
+from lam.models.rendering.utils.utils import trunc_exp, MLP
+from einops import rearrange, repeat
+
+
+inverse_sigmoid = lambda x: np.log(x / (1 - x))
+
+
+class GaussianModel:
+ def __init__(self, xyz=None, opacity=None, rotation=None, scaling=None, shs=None, offset=None, ply_path=None, sh2rgb=False, albedo=None, lights=None) -> None:
+ self.xyz: Tensor = xyz
+ self.opacity: Tensor = opacity
+ self.rotation: Tensor = rotation
+ self.scaling: Tensor = scaling
+ self.shs: Tensor = shs
+ self.albedo: Tensor = albedo
+ self.offset: Tensor = offset
+ self.lights: Tensor = lights
+ if ply_path is not None:
+ self.load_ply(ply_path, sh2rgb=sh2rgb)
+
+ def update_lights(self, lights):
+ self.lights = lights
+
+ def update_albedo(self, albedo):
+ self.albedo = albedo
+
+ def update_shs(self, shs):
+ self.shs = shs
+
+ def to_cuda(self):
+ self.xyz = self.xyz.cuda()
+ self.opacity = self.opacity.cuda()
+ self.rotation = self.rotation.cuda()
+ self.scaling = self.scaling.cuda()
+ self.shs = self.shs.cuda()
+ self.offset = self.offset.cuda()
+ self.albedo = self.albedo.cuda()
+
+ def construct_list_of_attributes(self):
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
+ if len(self.shs.shape) == 2:
+ features_dc = self.shs[:, :3].unsqueeze(1)
+ features_rest = self.shs[:, 3:].unsqueeze(1)
+ else:
+ features_dc = self.shs[:, :1]
+ features_rest = self.shs[:, 1:]
+ for i in range(features_dc.shape[1]*features_dc.shape[2]):
+ l.append('f_dc_{}'.format(i))
+ for i in range(features_rest.shape[1]*features_rest.shape[2]):
+ l.append('f_rest_{}'.format(i))
+ l.append('opacity')
+ for i in range(self.scaling.shape[1]):
+ l.append('scale_{}'.format(i))
+ for i in range(self.rotation.shape[1]):
+ l.append('rot_{}'.format(i))
+ return l
+
+ def save_ply(self, path, rgb2sh=False, offset2xyz=False, albedo2rgb=False):
+ if offset2xyz:
+ xyz = self.offset.detach().cpu().float().numpy()
+ else:
+ xyz = self.xyz.detach().cpu().float().numpy()
+ if albedo2rgb:
+ self.shs = self.albedo
+ normals = np.zeros_like(xyz)
+ if len(self.shs.shape) == 2:
+ features_dc = self.shs[:, :3].unsqueeze(1).float()
+ features_rest = self.shs[:, 3:].unsqueeze(1).float()
+ else:
+ features_dc = self.shs[:, :1].float()
+ features_rest = self.shs[:, 1:].float()
+ f_dc = features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy()
+ f_rest = features_rest.detach().flatten(start_dim=1).contiguous().cpu().numpy()
+ if rgb2sh:
+ from lam.models.rendering.utils.sh_utils import RGB2SH
+ f_dc = RGB2SH(f_dc)
+ opacities = inverse_sigmoid(torch.clamp(self.opacity, 1e-3, 1 - 1e-3).detach().cpu().float().numpy())
+ scale = np.log(self.scaling.detach().cpu().float().numpy())
+ rotation = self.rotation.detach().cpu().float().numpy()
+
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
+
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
+ elements[:] = list(map(tuple, attributes))
+ el = PlyElement.describe(elements, 'vertex')
+ PlyData([el]).write(path)
+
+ def save_ply_nodeact(self, path, rgb2sh=False, albedo2rgb=False):
+ if albedo2rgb:
+ self.shs = self.albedo
+ xyz = self.xyz.detach().cpu().float().numpy()
+ normals = np.zeros_like(xyz)
+ if len(self.shs.shape) == 2:
+ features_dc = self.shs[:, :3].unsqueeze(1).float()
+ features_rest = self.shs[:, 3:].unsqueeze(1).float()
+ else:
+ features_dc = self.shs[:, :1].float()
+ features_rest = self.shs[:, 1:].float()
+ f_dc = features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy()
+ f_rest = features_rest.detach().flatten(start_dim=1).contiguous().cpu().numpy()
+ if rgb2sh:
+ from lam.models.rendering.utils.sh_utils import RGB2SH
+ f_dc = RGB2SH(f_dc)
+ opacities = self.opacity.detach().cpu().float().numpy()
+ scale = self.scaling.detach().cpu().float().numpy()
+ rotation = self.rotation.detach().cpu().float().numpy()
+
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
+
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
+ elements[:] = list(map(tuple, attributes))
+ el = PlyElement.describe(elements, 'vertex')
+ PlyData([el]).write(path)
+
+ def load_ply(self, path, sh2rgb=False):
+ plydata = PlyData.read(path)
+
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
+ np.asarray(plydata.elements[0]["y"]),
+ np.asarray(plydata.elements[0]["z"])), axis=1)
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+ self.sh_degree = 0
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
+ extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
+ for idx, attr_name in enumerate(extra_f_names):
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.sh_degree + 1) ** 2 - 1))
+
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+ scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
+ for idx, attr_name in enumerate(scale_names):
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")]
+ rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
+ for idx, attr_name in enumerate(rot_names):
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ self.xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cpu").requires_grad_(False))
+ self.features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cpu").transpose(1, 2).contiguous().requires_grad_(False))
+ if sh2rgb:
+ from lam.models.rendering.utils.sh_utils import SH2RGB
+ self.features_dc = SH2RGB(self.features_dc)
+ self.features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cpu").transpose(1, 2).contiguous().requires_grad_(False))
+ self.shs = torch.cat([self.features_dc, self.features_rest], dim=1)
+ self.opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cpu").requires_grad_(False))
+ self.scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cpu").requires_grad_(False))
+ self.rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cpu").requires_grad_(False))
+ self.offset = nn.Parameter(torch.zeros_like(self.xyz).requires_grad_(False))
+ self.albedo = nn.Parameter(torch.zeros_like(self.shs).requires_grad_(False))
+ self.lights = nn.Parameter(torch.zeros_like(self.shs).requires_grad_(False))
+ if sh2rgb:
+ self.opacity = nn.functional.sigmoid(self.opacity)
+ self.scaling = trunc_exp(self.scaling)
+
+ self.active_sh_degree = self.sh_degree
diff --git a/lam/models/rendering/gs_renderer.py b/lam/models/rendering/gs_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..789127894adbc3811ba5d5410eb9216bf8d5e764
--- /dev/null
+++ b/lam/models/rendering/gs_renderer.py
@@ -0,0 +1,939 @@
+import os
+from dataclasses import dataclass, field
+from collections import defaultdict
+try:
+ from diff_gaussian_rasterization_wda import GaussianRasterizationSettings, GaussianRasterizer
+except:
+ from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
+from plyfile import PlyData, PlyElement
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import math
+import copy
+from diffusers.utils import is_torch_version
+from lam.models.rendering.flame_model.flame import FlameHeadSubdivided
+from lam.models.transformer import TransformerDecoder
+from pytorch3d.transforms import matrix_to_quaternion
+from lam.models.rendering.utils.typing import *
+from lam.models.rendering.utils.utils import trunc_exp, MLP
+from lam.models.rendering.gaussian_model import GaussianModel
+from einops import rearrange, repeat
+from pytorch3d.ops.points_normals import estimate_pointcloud_normals
+os.environ["PYOPENGL_PLATFORM"] = "egl"
+from pytorch3d.structures import Meshes, Pointclouds
+from pytorch3d.renderer import (
+ AmbientLights,
+ PerspectiveCameras,
+ SoftSilhouetteShader,
+ SoftPhongShader,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRendererWithFragments,
+ MeshRasterizer,
+ TexturesVertex,
+)
+from pytorch3d.renderer.blending import BlendParams, softmax_rgb_blend
+import lam.models.rendering.utils.mesh_utils as mesh_utils
+from lam.models.rendering.utils.point_utils import depth_to_normal
+from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes
+
+inverse_sigmoid = lambda x: np.log(x / (1 - x))
+
+
+def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
+ Rt = np.zeros((4, 4))
+ Rt[:3, :3] = R.transpose()
+ Rt[:3, 3] = t
+ Rt[3, 3] = 1.0
+
+ C2W = np.linalg.inv(Rt)
+ cam_center = C2W[:3, 3]
+ cam_center = (cam_center + translate) * scale
+ C2W[:3, 3] = cam_center
+ Rt = np.linalg.inv(C2W)
+ return np.float32(Rt)
+
+def getProjectionMatrix(znear, zfar, fovX, fovY):
+ tanHalfFovY = math.tan((fovY / 2))
+ tanHalfFovX = math.tan((fovX / 2))
+
+ top = tanHalfFovY * znear
+ bottom = -top
+ right = tanHalfFovX * znear
+ left = -right
+
+ P = torch.zeros(4, 4)
+
+ z_sign = 1.0
+
+ P[0, 0] = 2.0 * znear / (right - left)
+ P[1, 1] = 2.0 * znear / (top - bottom)
+ P[0, 2] = (right + left) / (right - left)
+ P[1, 2] = (top + bottom) / (top - bottom)
+ P[3, 2] = z_sign
+ P[2, 2] = z_sign * zfar / (zfar - znear)
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
+ return P
+
+def intrinsic_to_fov(intrinsic, w, h):
+ fx, fy = intrinsic[0, 0], intrinsic[1, 1]
+ fov_x = 2 * torch.arctan2(w, 2 * fx)
+ fov_y = 2 * torch.arctan2(h, 2 * fy)
+ return fov_x, fov_y
+
+
+class Camera:
+ def __init__(self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0) -> None:
+ self.FoVx = FoVx
+ self.FoVy = FoVy
+ self.height = int(height)
+ self.width = int(width)
+ self.world_view_transform = w2c.transpose(0, 1)
+ self.intrinsic = intrinsic
+
+ self.zfar = 100.0
+ self.znear = 0.01
+
+ self.trans = trans
+ self.scale = scale
+
+ self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(w2c.device)
+ self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
+
+ @staticmethod
+ def from_c2w(c2w, intrinsic, height, width):
+ w2c = torch.inverse(c2w)
+ FoVx, FoVy = intrinsic_to_fov(intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device))
+ return Camera(w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width)
+
+
+class GSLayer(nn.Module):
+ def __init__(self, in_channels, use_rgb,
+ clip_scaling=0.2,
+ init_scaling=-5.0,
+ scale_sphere=False,
+ init_density=0.1,
+ sh_degree=None,
+ xyz_offset=True,
+ restrict_offset=True,
+ xyz_offset_max_step=None,
+ fix_opacity=False,
+ fix_rotation=False,
+ use_fine_feat=False,
+ pred_res=False,
+ ):
+ super().__init__()
+ self.clip_scaling = clip_scaling
+ self.use_rgb = use_rgb
+ self.restrict_offset = restrict_offset
+ self.xyz_offset = xyz_offset
+ self.xyz_offset_max_step = xyz_offset_max_step # 1.2 / 32
+ self.fix_opacity = fix_opacity
+ self.fix_rotation = fix_rotation
+ self.use_fine_feat = use_fine_feat
+ self.scale_sphere = scale_sphere
+ self.pred_res = pred_res
+
+ self.attr_dict ={
+ "shs": (sh_degree + 1) ** 2 * 3,
+ "scaling": 3 if not scale_sphere else 1,
+ "xyz": 3,
+ "opacity": None,
+ "rotation": None
+ }
+ if not self.fix_opacity:
+ self.attr_dict["opacity"] = 1
+ if not self.fix_rotation:
+ self.attr_dict["rotation"] = 4
+
+ self.out_layers = nn.ModuleDict()
+ for key, out_ch in self.attr_dict.items():
+ if out_ch is None:
+ layer = nn.Identity()
+ else:
+ if key == "shs" and use_rgb:
+ out_ch = 3
+ if key == "shs":
+ shs_out_ch = out_ch
+ if pred_res:
+ layer = nn.Linear(in_channels+out_ch, out_ch)
+ else:
+ layer = nn.Linear(in_channels, out_ch)
+ # initialize
+ if not (key == "shs" and use_rgb):
+ if key == "opacity" and self.fix_opacity:
+ pass
+ elif key == "rotation" and self.fix_rotation:
+ pass
+ else:
+ nn.init.constant_(layer.weight, 0)
+ nn.init.constant_(layer.bias, 0)
+ if key == "scaling":
+ nn.init.constant_(layer.bias, init_scaling)
+ elif key == "rotation":
+ if not self.fix_rotation:
+ nn.init.constant_(layer.bias, 0)
+ nn.init.constant_(layer.bias[0], 1.0)
+ elif key == "opacity":
+ if not self.fix_opacity:
+ nn.init.constant_(layer.bias, inverse_sigmoid(init_density))
+ self.out_layers[key] = layer
+
+ if self.use_fine_feat:
+ fine_shs_layer = nn.Linear(in_channels, shs_out_ch)
+ nn.init.constant_(fine_shs_layer.weight, 0)
+ nn.init.constant_(fine_shs_layer.bias, 0)
+ self.out_layers["fine_shs"] = fine_shs_layer
+
+ def forward(self, x, pts, x_fine=None, gs_raw_attr=None, ret_raw=False, vtx_sym_idxs=None):
+ assert len(x.shape) == 2
+ ret = {}
+ if ret_raw:
+ raw_attr = {}
+ ori_x = x
+ for k in self.attr_dict:
+ # if vtx_sym_idxs is not None and k in ["shs", "scaling", "opacity"]:
+ if vtx_sym_idxs is not None and k in ["shs", "scaling", "opacity", "rotation"]:
+ # print("==="*16*3, "\n\n\n"+"use sym mean.", "\n"+"==="*16*3)
+ # x = (x + x[vtx_sym_idxs.to(x.device), :]) / 2.
+ x = ori_x[vtx_sym_idxs.to(x.device), :]
+ else:
+ x = ori_x
+ layer =self.out_layers[k]
+ if self.pred_res and (not self.fix_opacity or k != "opacity") and (not self.fix_rotation or k != "rotation"):
+ v = layer(torch.cat([gs_raw_attr[k], x], dim=-1))
+ v = gs_raw_attr[k] + v
+ else:
+ v = layer(x)
+ if ret_raw:
+ raw_attr[k] = v
+ if k == "rotation":
+ if self.fix_rotation:
+ v = matrix_to_quaternion(torch.eye(3).type_as(x)[None,: , :].repeat(x.shape[0], 1, 1)) # constant rotation
+ else:
+ # assert len(x.shape) == 2
+ v = torch.nn.functional.normalize(v)
+ elif k == "scaling":
+ v = trunc_exp(v)
+ if self.scale_sphere:
+ assert v.shape[-1] == 1
+ v = torch.cat([v, v, v], dim=-1)
+ if self.clip_scaling is not None:
+ v = torch.clamp(v, min=0, max=self.clip_scaling)
+ elif k == "opacity":
+ if self.fix_opacity:
+ v = torch.ones_like(x)[..., 0:1]
+ else:
+ v = torch.sigmoid(v)
+ elif k == "shs":
+ if self.use_rgb:
+ v[..., :3] = torch.sigmoid(v[..., :3])
+ if self.use_fine_feat:
+ v_fine = self.out_layers["fine_shs"](x_fine)
+ v_fine = torch.tanh(v_fine)
+ v = v + v_fine
+ else:
+ if self.use_fine_feat:
+ v_fine = self.out_layers["fine_shs"](x_fine)
+ v = v + v_fine
+ v = torch.reshape(v, (v.shape[0], -1, 3))
+ elif k == "xyz":
+ # TODO check
+ if self.restrict_offset:
+ max_step = self.xyz_offset_max_step
+ v = (torch.sigmoid(v) - 0.5) * max_step
+ if self.xyz_offset:
+ pass
+ else:
+ assert NotImplementedError
+ ret["offset"] = v
+ v = pts + v
+ ret[k] = v
+
+ if ret_raw:
+ return GaussianModel(**ret), raw_attr
+ else:
+ return GaussianModel(**ret)
+
+
+class PointEmbed(nn.Module):
+ def __init__(self, hidden_dim=48, dim=128):
+ super().__init__()
+
+ assert hidden_dim % 6 == 0
+
+ self.embedding_dim = hidden_dim
+ e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
+ e = torch.stack([
+ torch.cat([e, torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6), e,
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6), e]),
+ ])
+ self.register_buffer('basis', e) # 3 x 16
+
+ self.mlp = nn.Linear(self.embedding_dim+3, dim)
+ self.norm = nn.LayerNorm(dim)
+
+ @staticmethod
+ def embed(input, basis):
+ projections = torch.einsum(
+ 'bnd,de->bne', input, basis)
+ embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
+ return embeddings
+
+ def forward(self, input):
+ # input: B x N x 3
+ embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C
+ embed = self.norm(embed)
+ return embed
+
+
+class CrossAttnBlock(nn.Module):
+ """
+ Transformer block that takes in a cross-attention condition.
+ Designed for SparseLRM architecture.
+ """
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
+ def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float=None,
+ attn_drop: float = 0., attn_bias: bool = False,
+ mlp_ratio: float = 4., mlp_drop: float = 0., feedforward=False):
+ super().__init__()
+ # TODO check already apply normalization
+ # self.norm_q = nn.LayerNorm(inner_dim, eps=eps)
+ # self.norm_k = nn.LayerNorm(cond_dim, eps=eps)
+ self.norm_q = nn.Identity()
+ self.norm_k = nn.Identity()
+
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
+
+ self.mlp = None
+ if feedforward:
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
+ self.self_attn = nn.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads,
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
+ self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
+ nn.GELU(),
+ nn.Dropout(mlp_drop),
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
+ nn.Dropout(mlp_drop),
+ )
+
+ def forward(self, x, cond):
+ # x: [N, L, D]
+ # cond: [N, L_cond, D_cond]
+ x = self.cross_attn(self.norm_q(x), self.norm_k(cond), cond, need_weights=False)[0]
+ if self.mlp is not None:
+ before_sa = self.norm2(x)
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
+ x = x + self.mlp(self.norm3(x))
+ return x
+
+
+class DecoderCrossAttn(nn.Module):
+ def __init__(self, query_dim, context_dim, num_heads, mlp=False, decode_with_extra_info=None):
+ super().__init__()
+ self.query_dim = query_dim
+ self.context_dim = context_dim
+
+ self.cross_attn = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim,
+ num_heads=num_heads, feedforward=mlp,
+ eps=1e-5)
+ self.decode_with_extra_info = decode_with_extra_info
+ if decode_with_extra_info is not None:
+ if decode_with_extra_info["type"] == "dinov2p14_feat":
+ context_dim = decode_with_extra_info["cond_dim"]
+ self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim,
+ num_heads=num_heads, feedforward=False, eps=1e-5)
+ elif decode_with_extra_info["type"] == "decoder_dinov2p14_feat":
+ from lam.models.encoders.dinov2_wrapper import Dinov2Wrapper
+ self.encoder = Dinov2Wrapper(model_name='dinov2_vits14_reg', freeze=False, encoder_feat_dim=384)
+ self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=384,
+ num_heads=num_heads, feedforward=False,
+ eps=1e-5)
+ elif decode_with_extra_info["type"] == "decoder_resnet18_feat":
+ from lam.models.encoders.xunet_wrapper import XnetWrapper
+ self.encoder = XnetWrapper(model_name='resnet18', freeze=False, encoder_feat_dim=64)
+ self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=64,
+ num_heads=num_heads, feedforward=False,
+ eps=1e-5)
+
+ def resize_image(self, image, multiply):
+ B, _, H, W = image.shape
+ new_h, new_w = math.ceil(H / multiply) * multiply, math.ceil(W / multiply) * multiply
+ image = F.interpolate(image, (new_h, new_w), align_corners=True, mode="bilinear")
+ return image
+
+ def forward(self, pcl_query, pcl_latent, extra_info=None):
+ out = self.cross_attn(pcl_query, pcl_latent)
+ if self.decode_with_extra_info is not None:
+ out_dict = {}
+ out_dict["coarse"] = out
+ if self.decode_with_extra_info["type"] == "dinov2p14_feat":
+ out = self.cross_attn_color(out, extra_info["image_feats"])
+ out_dict["fine"] = out
+ return out_dict
+ elif self.decode_with_extra_info["type"] == "decoder_dinov2p14_feat":
+ img_feat = self.encoder(extra_info["image"])
+ out = self.cross_attn_color(out, img_feat)
+ out_dict["fine"] = out
+ return out_dict
+ elif self.decode_with_extra_info["type"] == "decoder_resnet18_feat":
+ image = extra_info["image"]
+ image = self.resize_image(image, multiply=32)
+ img_feat = self.encoder(image)
+ out = self.cross_attn_color(out, img_feat)
+ out_dict["fine"] = out
+ return out_dict
+ return out
+
+
+class GS3DRenderer(nn.Module):
+ def __init__(self, human_model_path, subdivide_num, smpl_type, feat_dim, query_dim,
+ use_rgb, sh_degree, xyz_offset_max_step, mlp_network_config,
+ expr_param_dim, shape_param_dim,
+ clip_scaling=0.2,
+ scale_sphere=False,
+ skip_decoder=False,
+ fix_opacity=False,
+ fix_rotation=False,
+ decode_with_extra_info=None,
+ gradient_checkpointing=False,
+ add_teeth=True,
+ teeth_bs_flag=False,
+ oral_mesh_flag=False,
+ **kwargs,
+ ):
+ super().__init__()
+ print(f"#########scale sphere:{scale_sphere}, add_teeth:{add_teeth}")
+ self.gradient_checkpointing = gradient_checkpointing
+ self.skip_decoder = skip_decoder
+ self.smpl_type = smpl_type
+ assert self.smpl_type == "flame"
+ self.sym_rend2 = True
+ self.teeth_bs_flag = teeth_bs_flag
+ self.oral_mesh_flag = oral_mesh_flag
+ self.render_rgb = kwargs.get("render_rgb", True)
+ print("==="*16*3, "\n Render rgb:", self.render_rgb, "\n"+"==="*16*3)
+
+ self.scaling_modifier = 1.0
+ self.sh_degree = sh_degree
+ if use_rgb:
+ self.sh_degree = 0
+
+ use_rgb = use_rgb
+
+ self.flame_model = FlameHeadSubdivided(
+ 300,
+ 100,
+ add_teeth=add_teeth,
+ add_shoulder=False,
+ flame_model_path=f'{human_model_path}/flame_assets/flame/flame2023.pkl',
+ flame_lmk_embedding_path=f"{human_model_path}/flame_assets/flame/landmark_embedding_with_eyes.npy",
+ flame_template_mesh_path=f"{human_model_path}/flame_assets/flame/head_template_mesh.obj",
+ flame_parts_path=f"{human_model_path}/flame_assets/flame/FLAME_masks.pkl",
+ subdivide_num=subdivide_num,
+ teeth_bs_flag=teeth_bs_flag,
+ oral_mesh_flag=oral_mesh_flag
+ )
+
+ if not self.skip_decoder:
+ self.pcl_embed = PointEmbed(dim=query_dim)
+
+ self.mlp_network_config = mlp_network_config
+ if self.mlp_network_config is not None:
+ self.mlp_net = MLP(query_dim, query_dim, **self.mlp_network_config)
+
+ init_scaling = -5.0
+ self.gs_net = GSLayer(in_channels=query_dim,
+ use_rgb=use_rgb,
+ sh_degree=self.sh_degree,
+ clip_scaling=clip_scaling,
+ scale_sphere=scale_sphere,
+ init_scaling=init_scaling,
+ init_density=0.1,
+ xyz_offset=True,
+ restrict_offset=True,
+ xyz_offset_max_step=xyz_offset_max_step,
+ fix_opacity=fix_opacity,
+ fix_rotation=fix_rotation,
+ use_fine_feat=True if decode_with_extra_info is not None and decode_with_extra_info["type"] is not None else False,
+ )
+
+ def forward_single_view(self,
+ gs: GaussianModel,
+ viewpoint_camera: Camera,
+ background_color: Optional[Float[Tensor, "3"]],
+ ):
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+ screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0
+ try:
+ screenspace_points.retain_grad()
+ except:
+ pass
+
+ bg_color = background_color
+ # Set up rasterization configuration
+ tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+ tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+ GSRSettings = GaussianRasterizationSettings
+ GSR = GaussianRasterizer
+
+ raster_settings = GSRSettings(
+ image_height=int(viewpoint_camera.height),
+ image_width=int(viewpoint_camera.width),
+ tanfovx=tanfovx,
+ tanfovy=tanfovy,
+ bg=bg_color,
+ scale_modifier=self.scaling_modifier,
+ viewmatrix=viewpoint_camera.world_view_transform,
+ projmatrix=viewpoint_camera.full_proj_transform.float(),
+ sh_degree=self.sh_degree,
+ campos=viewpoint_camera.camera_center,
+ prefiltered=False,
+ debug=False
+ )
+
+ rasterizer = GSR(raster_settings=raster_settings)
+
+ means3D = gs.xyz
+ means2D = screenspace_points
+ opacity = gs.opacity
+
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+ # scaling / rotation by the rasterizer.
+ scales = None
+ rotations = None
+ cov3D_precomp = None
+ scales = gs.scaling
+ rotations = gs.rotation
+
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+ shs = None
+ colors_precomp = None
+ if self.gs_net.use_rgb:
+ colors_precomp = gs.shs.squeeze(1)
+ else:
+ shs = gs.shs
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
+ # torch.cuda.synchronize()
+ # with boxx.timeit():
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ raster_ret = rasterizer(
+ means3D = means3D.float(),
+ means2D = means2D.float(),
+ shs = shs.float() if not self.gs_net.use_rgb else None,
+ colors_precomp = colors_precomp.float() if colors_precomp is not None else None,
+ opacities = opacity.float(),
+ scales = scales.float(),
+ rotations = rotations.float(),
+ cov3D_precomp = cov3D_precomp
+ )
+ rendered_image, radii, rendered_depth, rendered_alpha = raster_ret
+
+ ret = {
+ "comp_rgb": rendered_image.permute(1, 2, 0), # [H, W, 3]
+ "comp_rgb_bg": bg_color,
+ 'comp_mask': rendered_alpha.permute(1, 2, 0),
+ 'comp_depth': rendered_depth.permute(1, 2, 0),
+ }
+
+ return ret
+
+ def animate_gs_model(self, gs_attr: GaussianModel, query_points, flame_data, debug=False):
+ """
+ query_points: [N, 3]
+ """
+ device = gs_attr.xyz.device
+ if debug:
+ N = gs_attr.xyz.shape[0]
+ gs_attr.xyz = torch.ones_like(gs_attr.xyz) * 0.0
+
+ rotation = matrix_to_quaternion(torch.eye(3).float()[None, :, :].repeat(N, 1, 1)).to(device) # constant rotation
+ opacity = torch.ones((N, 1)).float().to(device) # constant opacity
+
+ gs_attr.opacity = opacity
+ gs_attr.rotation = rotation
+ # gs_attr.scaling = torch.ones_like(gs_attr.scaling) * 0.05
+ # print(gs_attr.shs.shape)
+
+ with torch.autocast(device_type=device.type, dtype=torch.float32):
+ # mean_3d = query_points + gs_attr.xyz # [N, 3]
+ mean_3d = gs_attr.xyz # [N, 3]
+
+ num_view = flame_data["expr"].shape[0] # [Nv, 100]
+ mean_3d = mean_3d.unsqueeze(0).repeat(num_view, 1, 1) # [Nv, N, 3]
+ query_points = query_points.unsqueeze(0).repeat(num_view, 1, 1)
+
+ if self.teeth_bs_flag:
+ expr = torch.cat([flame_data['expr'], flame_data['teeth_bs']], dim=-1)
+ else:
+ expr = flame_data["expr"]
+ ret = self.flame_model.animation_forward(v_cano=mean_3d,
+ shape=flame_data["betas"].repeat(num_view, 1),
+ expr=expr,
+ rotation=flame_data["rotation"],
+ neck=flame_data["neck_pose"],
+ jaw=flame_data["jaw_pose"],
+ eyes=flame_data["eyes_pose"],
+ translation=flame_data["translation"],
+ zero_centered_at_root_node=False,
+ return_landmarks=False,
+ return_verts_cano=False,
+ # static_offset=flame_data['static_offset'].to('cuda'),
+ static_offset=None,
+ )
+ mean_3d = ret["animated"]
+
+ gs_attr_list = []
+ for i in range(num_view):
+ gs_attr_copy = GaussianModel(xyz=mean_3d[i],
+ opacity=gs_attr.opacity,
+ rotation=gs_attr.rotation,
+ scaling=gs_attr.scaling,
+ shs=gs_attr.shs,
+ albedo=gs_attr.albedo,
+ lights=gs_attr.lights,
+ offset=gs_attr.offset) # [N, 3]
+ gs_attr_list.append(gs_attr_copy)
+
+ return gs_attr_list
+
+
+ def forward_gs_attr(self, x, query_points, flame_data, debug=False, x_fine=None, vtx_sym_idxs=None):
+ """
+ x: [N, C] Float[Tensor, "Np Cp"],
+ query_points: [N, 3] Float[Tensor, "Np 3"]
+ """
+ device = x.device
+ if self.mlp_network_config is not None:
+ x = self.mlp_net(x)
+ if x_fine is not None:
+ x_fine = self.mlp_net(x_fine)
+ gs_attr: GaussianModel = self.gs_net(x, query_points, x_fine, vtx_sym_idxs=vtx_sym_idxs)
+ return gs_attr
+
+
+ def get_query_points(self, flame_data, device):
+ with torch.no_grad():
+ with torch.autocast(device_type=device.type, dtype=torch.float32):
+ # print(flame_data["betas"].shape, flame_data["face_offset"].shape, flame_data["joint_offset"].shape)
+ # positions, _, transform_mat_neutral_pose = self.flame_model.get_query_points(flame_data, device=device) # [B, N, 3]
+ positions = self.flame_model.get_cano_verts(shape_params=flame_data["betas"]) # [B, N, 3]
+ # print(f"positions shape:{positions.shape}")
+
+ return positions, flame_data
+
+ def query_latent_feat(self,
+ positions: Float[Tensor, "*B N1 3"],
+ flame_data,
+ latent_feat: Float[Tensor, "*B N2 C"],
+ extra_info):
+ device = latent_feat.device
+ if self.skip_decoder:
+ gs_feats = latent_feat
+ assert positions is not None
+ else:
+ assert positions is None
+ if positions is None:
+ positions, flame_data = self.get_query_points(flame_data, device)
+
+ with torch.autocast(device_type=device.type, dtype=torch.float32):
+ pcl_embed = self.pcl_embed(positions)
+ gs_feats = pcl_embed
+
+ return gs_feats, positions, flame_data
+
+ def forward_single_batch(
+ self,
+ gs_list: list[GaussianModel],
+ c2ws: Float[Tensor, "Nv 4 4"],
+ intrinsics: Float[Tensor, "Nv 4 4"],
+ height: int,
+ width: int,
+ background_color: Optional[Float[Tensor, "Nv 3"]],
+ debug: bool=False,
+ ):
+ out_list = []
+ self.device = gs_list[0].xyz.device
+ for v_idx, (c2w, intrinsic) in enumerate(zip(c2ws, intrinsics)):
+ out_list.append(self.forward_single_view(
+ gs_list[v_idx],
+ Camera.from_c2w(c2w, intrinsic, height, width),
+ background_color[v_idx],
+ ))
+
+ out = defaultdict(list)
+ for out_ in out_list:
+ for k, v in out_.items():
+ out[k].append(v)
+ out = {k: torch.stack(v, dim=0) for k, v in out.items()}
+ out["3dgs"] = gs_list
+
+ return out
+
+ def get_sing_batch_smpl_data(self, smpl_data, bidx):
+ smpl_data_single_batch = {}
+ for k, v in smpl_data.items():
+ smpl_data_single_batch[k] = v[bidx] # e.g. body_pose: [B, N_v, 21, 3] -> [N_v, 21, 3]
+ if k == "betas" or (k == "joint_offset") or (k == "face_offset"):
+ smpl_data_single_batch[k] = v[bidx:bidx+1] # e.g. betas: [B, 100] -> [1, 100]
+ return smpl_data_single_batch
+
+ def get_single_view_smpl_data(self, smpl_data, vidx):
+ smpl_data_single_view = {}
+ for k, v in smpl_data.items():
+ assert v.shape[0] == 1
+ if k == "betas" or (k == "joint_offset") or (k == "face_offset") or (k == "transform_mat_neutral_pose"):
+ smpl_data_single_view[k] = v # e.g. betas: [1, 100] -> [1, 100]
+ else:
+ smpl_data_single_view[k] = v[:, vidx: vidx + 1] # e.g. body_pose: [1, N_v, 21, 3] -> [1, 1, 21, 3]
+ return smpl_data_single_view
+
+ def forward_gs(self,
+ gs_hidden_features: Float[Tensor, "B Np Cp"],
+ query_points: Float[Tensor, "B Np_q 3"],
+ flame_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
+ additional_features: Optional[dict] = None,
+ debug: bool = False,
+ **kwargs):
+
+ batch_size = gs_hidden_features.shape[0]
+
+ query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features,
+ additional_features)
+
+ gs_model_list = []
+ all_query_points = []
+ for b in range(batch_size):
+ all_query_points.append(query_points[b:b+1, :])
+ if isinstance(query_gs_features, dict):
+ ret_gs = self.forward_gs_attr(query_gs_features["coarse"][b], query_points[b], None, debug,
+ x_fine=query_gs_features["fine"][b], vtx_sym_idxs=None)
+ else:
+ ret_gs = self.forward_gs_attr(query_gs_features[b], query_points[b], None, debug, vtx_sym_idxs=None)
+
+ ret_gs.update_albedo(ret_gs.shs.clone())
+
+ gs_model_list.append(ret_gs)
+
+ query_points = torch.cat(all_query_points, dim=0)
+ return gs_model_list, query_points, flame_data, query_gs_features
+
+ def forward_res_refine_gs(self,
+ gs_hidden_features: Float[Tensor, "B Np Cp"],
+ query_points: Float[Tensor, "B Np_q 3"],
+ flame_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
+ additional_features: Optional[dict] = None,
+ debug: bool = False,
+ gs_raw_attr_list: list = None,
+ **kwargs):
+
+ batch_size = gs_hidden_features.shape[0]
+
+ query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features,
+ additional_features)
+
+ gs_model_list = []
+ for b in range(batch_size):
+ gs_model = self.gs_refine_net(query_gs_features[b], query_points[b], x_fine=None, gs_raw_attr=gs_raw_attr_list[b])
+ gs_model_list.append(gs_model)
+ return gs_model_list, query_points, flame_data, query_gs_features
+
+ def forward_animate_gs(self, gs_model_list, query_points, flame_data, c2w, intrinsic, height, width,
+ background_color, debug=False):
+ batch_size = len(gs_model_list)
+ out_list = []
+
+ for b in range(batch_size):
+ gs_model = gs_model_list[b]
+ query_pt = query_points[b]
+ animatable_gs_model_list: list[GaussianModel] = self.animate_gs_model(gs_model,
+ query_pt,
+ self.get_sing_batch_smpl_data(flame_data, b),
+ debug=debug)
+ assert len(animatable_gs_model_list) == c2w.shape[1]
+ out_list.append(self.forward_single_batch(
+ animatable_gs_model_list,
+ c2w[b],
+ intrinsic[b],
+ height, width,
+ background_color[b] if background_color is not None else None,
+ debug=debug))
+
+ out = defaultdict(list)
+ for out_ in out_list:
+ for k, v in out_.items():
+ out[k].append(v)
+ for k, v in out.items():
+ if isinstance(v[0], torch.Tensor):
+ out[k] = torch.stack(v, dim=0)
+ else:
+ out[k] = v
+
+ render_keys = ["comp_rgb", "comp_mask", "comp_depth"]
+ for key in render_keys:
+ out[key] = rearrange(out[key], "b v h w c -> b v c h w")
+
+ return out
+
+ def project_single_view_feats(self, img_vtx_ids, feats, nv, inter_feat=True):
+ b, h, w, k = img_vtx_ids.shape
+ c, ih, iw = feats.shape
+ vtx_ids = img_vtx_ids
+ if h != ih or w != iw:
+ if inter_feat:
+ feats = torch.nn.functional.interpolate(
+ rearrange(feats, "(b c) h w -> b c h w", b=1).float(), (h, w)
+ ).squeeze(0)
+ vtx_ids = rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).long().squeeze(1)
+ else:
+ vtx_ids = torch.nn.functional.interpolate(
+ rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).float(), (ih, iw), mode="nearest"
+ ).long().squeeze(1)
+ else:
+ vtx_ids = rearrange(vtx_ids, "b h w k -> (b k) h w", b=1).long()
+ vis_mask = vtx_ids > 0
+ vtx_ids = vtx_ids[vis_mask] # n
+ vtx_ids = repeat(vtx_ids, "n -> n c", c=c)
+
+ feats = repeat(feats, "c h w -> k h w c", k=k).to(vtx_ids.device)
+ feats = feats[vis_mask, :] # n, c
+
+ sums = torch.zeros((nv, c), dtype=feats.dtype, device=feats.device)
+ counts = torch.zeros((nv), dtype=torch.int64, device=feats.device)
+
+ sums.scatter_add_(0, vtx_ids, feats)
+ one_hot = torch.ones_like(vtx_ids[:, 0], dtype=torch.int64).to(feats.device)
+ counts.scatter_add_(0, vtx_ids[:, 0], one_hot)
+ clamp_counts = counts.clamp(min=1)
+ mean_feats = sums / clamp_counts.view(-1, 1)
+ return mean_feats
+
+ def forward(self,
+ gs_hidden_features: Float[Tensor, "B Np Cp"],
+ query_points: Float[Tensor, "B Np 3"],
+ flame_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
+ c2w: Float[Tensor, "B Nv 4 4"],
+ intrinsic: Float[Tensor, "B Nv 4 4"],
+ height,
+ width,
+ additional_features: Optional[Float[Tensor, "B C H W"]] = None,
+ background_color: Optional[Float[Tensor, "B Nv 3"]] = None,
+ debug: bool = False,
+ **kwargs):
+
+ # need shape_params of flame_data to get querty points and get "transform_mat_neutral_pose"
+ gs_model_list, query_points, flame_data, query_gs_features = self.forward_gs(gs_hidden_features, query_points, flame_data=flame_data,
+ additional_features=additional_features, debug=debug)
+
+ out = self.forward_animate_gs(gs_model_list, query_points, flame_data, c2w, intrinsic, height, width, background_color, debug)
+
+ return out
+
+
+def test_head():
+ import cv2
+
+ human_model_path = "./pretrained_models/human_model_files"
+ device = "cuda"
+
+ from accelerate.utils import set_seed
+ set_seed(1234)
+
+ from lam.datasets.video_head import VideoHeadDataset
+ root_dir = "./train_data/vfhq_vhap/export"
+ meta_path = "./train_data/vfhq_vhap/label/valid_id_list.json"
+ # root_dir = "./train_data/nersemble/export"
+ # meta_path = "./train_data/nersemble/label/valid_id_list1.json"
+ dataset = VideoHeadDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=7,
+ render_image_res_low=512, render_image_res_high=512,
+ render_region_size=(512, 512), source_image_res=512,
+ enlarge_ratio=[0.8, 1.2],
+ debug=False)
+
+ data = dataset[0]
+
+ def get_flame_params(data):
+ flame_params = {}
+ flame_keys = ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas',\
+ 'rotation', 'neck_pose', 'eyes_pose', 'translation']
+ for k, v in data.items():
+ if k in flame_keys:
+ # print(k, v.shape)
+ flame_params[k] = data[k]
+ return flame_params
+
+ flame_data = get_flame_params(data)
+
+ flame_data_tmp = {}
+ for k, v in flame_data.items():
+ flame_data_tmp[k] = v.unsqueeze(0).to(device)
+ print(k, v.shape)
+ flame_data = flame_data_tmp
+
+ c2ws = data["c2ws"].unsqueeze(0).to(device)
+ intrs = data["intrs"].unsqueeze(0).to(device)
+ render_images = data["render_image"].numpy()
+ render_h = data["render_full_resolutions"][0, 0]
+ render_w= data["render_full_resolutions"][0, 1]
+ render_bg_colors = data["render_bg_colors"].unsqueeze(0).to(device)
+ print("c2ws", c2ws.shape, "intrs", intrs.shape, intrs)
+
+ gs_render = GS3DRenderer(human_model_path=human_model_path, subdivide_num=2, smpl_type="flame",
+ feat_dim=64, query_dim=64, use_rgb=True, sh_degree=3, mlp_network_config=None,
+ xyz_offset_max_step=0.0001, expr_param_dim=10, shape_param_dim=10,
+ fix_opacity=True, fix_rotation=True, clip_scaling=0.001, add_teeth=False)
+ gs_render.to(device)
+
+ out = gs_render.forward(gs_hidden_features=torch.zeros((1, 2048, 64)).float().to(device),
+ query_points=None,
+ flame_data=flame_data,
+ c2w=c2ws,
+ intrinsic=intrs,
+ height=render_h,
+ width=render_w,
+ background_color=render_bg_colors,
+ debug=False)
+
+ os.makedirs("./debug_vis/gs_render", exist_ok=True)
+ for k, v in out.items():
+ if k == "comp_rgb_bg":
+ print("comp_rgb_bg", v)
+ continue
+ for b_idx in range(len(v)):
+ if k == "3dgs":
+ for v_idx in range(len(v[b_idx])):
+ v[b_idx][v_idx].save_ply(f"./debug_vis/gs_render/{b_idx}_{v_idx}.ply")
+ continue
+ for v_idx in range(v.shape[1]):
+ save_path = os.path.join("./debug_vis/gs_render", f"{b_idx}_{v_idx}_{k}.jpg")
+ if "normal" in k:
+ img = ((v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() + 1.0) / 2. * 255).astype(np.uint8)
+ else:
+ img = (v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
+ print(v[b_idx, v_idx].shape, img.shape, save_path)
+ if "mask" in k:
+ render_img = render_images[v_idx].transpose(1, 2, 0) * 255
+ blend_img = (render_images[v_idx].transpose(1, 2, 0) * 255 * 0.5 + np.tile(img, (1, 1, 3)) * 0.5).clip(0, 255).astype(np.uint8)
+ cv2.imwrite(save_path, np.hstack([np.tile(img, (1, 1, 3)), render_img.astype(np.uint8), blend_img])[:, :, (2, 1, 0)])
+ else:
+ print(save_path, k)
+ cv2.imwrite(save_path, img)
+
+
+
+if __name__ == "__main__":
+ test_head()
diff --git a/lam/models/rendering/utils/__init__.py b/lam/models/rendering/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34
--- /dev/null
+++ b/lam/models/rendering/utils/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
diff --git a/lam/models/rendering/utils/math_utils.py b/lam/models/rendering/utils/math_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cf9d2b811e0acbc7923bc9126e010b52cb1a8af
--- /dev/null
+++ b/lam/models/rendering/utils/math_utils.py
@@ -0,0 +1,118 @@
+# MIT License
+
+# Copyright (c) 2022 Petr Kellnhofer
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+
+def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
+ """
+ Left-multiplies MxM @ NxM. Returns NxM.
+ """
+ res = torch.matmul(vectors4, matrix.T)
+ return res
+
+
+def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize vector lengths.
+ """
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
+
+def torch_dot(x: torch.Tensor, y: torch.Tensor):
+ """
+ Dot product of two tensors.
+ """
+ return (x * y).sum(-1)
+
+
+def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
+ """
+ Author: Petr Kellnhofer
+ Intersects rays with the [-1, 1] NDC volume.
+ Returns min and max distance of entry.
+ Returns -1 for no intersection.
+ https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
+ """
+ o_shape = rays_o.shape
+ rays_o = rays_o.detach().reshape(-1, 3)
+ rays_d = rays_d.detach().reshape(-1, 3)
+
+
+ bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
+ bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
+ bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
+ is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
+
+ # Precompute inverse for stability.
+ invdir = 1 / rays_d
+ sign = (invdir < 0).long()
+
+ # Intersect with YZ plane.
+ tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+ tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+
+ # Intersect with XZ plane.
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+ tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tymin)
+ tmax = torch.min(tmax, tymax)
+
+ # Intersect with XY plane.
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+ tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tzmin)
+ tmax = torch.min(tmax, tzmax)
+
+ # Mark invalid.
+ tmin[torch.logical_not(is_valid)] = -1
+ tmax[torch.logical_not(is_valid)] = -2
+
+ return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
+
+
+def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
+ """
+ Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
+ Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
+ """
+ # create a tensor of 'num' steps from 0 to 1
+ steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
+
+ # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
+ # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
+ # "cannot statically infer the expected size of a list in this contex", hence the code below
+ for i in range(start.ndim):
+ steps = steps.unsqueeze(-1)
+
+ # the output starts at 'start' and increments until 'stop' in each dimension
+ out = start[None] + steps * (stop - start)[None]
+
+ return out
diff --git a/lam/models/rendering/utils/mesh_utils.py b/lam/models/rendering/utils/mesh_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced91448bc48e3ede1991c72ff1aa4b80343ab78
--- /dev/null
+++ b/lam/models/rendering/utils/mesh_utils.py
@@ -0,0 +1,384 @@
+import os
+import cv2
+import math
+import torch
+import numpy as np
+import torch.nn.functional as F
+from collections import OrderedDict
+from scipy.ndimage import morphology
+from skimage.io import imsave
+
+
+def dict2obj(d):
+ if isinstance(d, list):
+ d = [dict2obj(x) for x in d]
+ if not isinstance(d, dict):
+ return d
+
+ class C(object):
+ pass
+
+ o = C()
+ for k in d:
+ o.__dict__[k] = dict2obj(d[k])
+ return o
+
+
+def check_mkdir(path):
+ if not os.path.exists(path):
+ print('making %s' % path)
+ os.makedirs(path)
+
+
+def l2_distance(verts1, verts2):
+ return torch.sqrt(((verts1 - verts2) ** 2).sum(2)).mean(1).mean()
+
+
+def quat2mat(quat):
+ """Convert quaternion coefficients to rotation matrix.
+ Args:
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w * x, w * y, w * z
+ xy, xz, yz = x * y, x * z, y * z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
+ 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
+ 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+
+def batch_rodrigues(theta):
+ # theta N x 3
+ batch_size = theta.shape[0]
+ l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim=1)
+
+ return quat2mat(quat)
+
+
+def batch_orth_proj(X, camera):
+ '''
+ X is N x num_points x 3
+ '''
+ camera = camera.clone().view(-1, 1, 3)
+ X_trans = X[:, :, :2] + camera[:, :, 1:]
+ X_trans = torch.cat([X_trans, X[:, :, 2:]], 2)
+ shape = X_trans.shape
+ # Xn = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
+ Xn = (camera[:, :, 0:1] * X_trans)
+ return Xn
+
+
+def batch_persp_proj(vertices, cam, f, t, orig_size=256, eps=1e-9):
+ '''
+ Calculate projective transformation of vertices given a projection matrix
+ Input parameters:
+ f: torch tensor of focal length
+ t: batch_size * 1 * 3 xyz translation in world coordinate
+ K: batch_size * 3 * 3 intrinsic camera matrix
+ R, t: batch_size * 3 * 3, batch_size * 1 * 3 extrinsic calibration parameters
+ dist_coeffs: vector of distortion coefficients
+ orig_size: original size of image captured by the camera
+ Returns: For each point [X,Y,Z] in world coordinates [u,v,z] where u,v are the coordinates of the projection in
+ pixels and z is the depth
+ '''
+ device = vertices.device
+
+ K = torch.tensor([f, 0., cam['c'][0], 0., f, cam['c'][1], 0., 0., 1.]).view(3, 3)[None, ...].repeat(
+ vertices.shape[0], 1).to(device)
+ R = batch_rodrigues(cam['r'][None, ...].repeat(vertices.shape[0], 1)).to(device)
+ dist_coeffs = cam['k'][None, ...].repeat(vertices.shape[0], 1).to(device)
+
+ vertices = torch.matmul(vertices, R.transpose(2, 1)) + t
+ x, y, z = vertices[:, :, 0], vertices[:, :, 1], vertices[:, :, 2]
+ x_ = x / (z + eps)
+ y_ = y / (z + eps)
+
+ # Get distortion coefficients from vector
+ k1 = dist_coeffs[:, None, 0]
+ k2 = dist_coeffs[:, None, 1]
+ p1 = dist_coeffs[:, None, 2]
+ p2 = dist_coeffs[:, None, 3]
+ k3 = dist_coeffs[:, None, 4]
+
+ # we use x_ for x' and x__ for x'' etc.
+ r = torch.sqrt(x_ ** 2 + y_ ** 2)
+ x__ = x_ * (1 + k1 * (r ** 2) + k2 * (r ** 4) + k3 * (r ** 6)) + 2 * p1 * x_ * y_ + p2 * (r ** 2 + 2 * x_ ** 2)
+ y__ = y_ * (1 + k1 * (r ** 2) + k2 * (r ** 4) + k3 * (r ** 6)) + p1 * (r ** 2 + 2 * y_ ** 2) + 2 * p2 * x_ * y_
+ vertices = torch.stack([x__, y__, torch.ones_like(z)], dim=-1)
+ vertices = torch.matmul(vertices, K.transpose(1, 2))
+ u, v = vertices[:, :, 0], vertices[:, :, 1]
+ v = orig_size - v
+ # map u,v from [0, img_size] to [-1, 1] to be compatible with the renderer
+ u = 2 * (u - orig_size / 2.) / orig_size
+ v = 2 * (v - orig_size / 2.) / orig_size
+ vertices = torch.stack([u, v, z], dim=-1)
+
+ return vertices
+
+
+def face_vertices(vertices, faces):
+ """
+ :param vertices: [batch size, number of vertices, 3]
+ :param faces: [batch size, number of faces, 3]
+ :return: [batch size, number of faces, 3, 3]
+ """
+ assert (vertices.ndimension() == 3)
+ assert (faces.ndimension() == 3)
+ assert (vertices.shape[0] == faces.shape[0])
+ assert (vertices.shape[2] == 3)
+ assert (faces.shape[2] == 3)
+
+ bs, nv = vertices.shape[:2]
+ bs, nf = faces.shape[:2]
+ device = vertices.device
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+ vertices = vertices.reshape((bs * nv, 3))
+ # pytorch only supports long and byte tensors for indexing
+ return vertices[faces.long()]
+
+
+def vertex_normals(vertices, faces):
+ """
+ :param vertices: [batch size, number of vertices, 3]
+ :param faces: [batch size, number of faces, 3]
+ :return: [batch size, number of vertices, 3]
+ """
+ assert (vertices.ndimension() == 3)
+ assert (faces.ndimension() == 3)
+ assert (vertices.shape[0] == faces.shape[0])
+ assert (vertices.shape[2] == 3)
+ assert (faces.shape[2] == 3)
+
+ bs, nv = vertices.shape[:2]
+ bs, nf = faces.shape[:2]
+ device = vertices.device
+ normals = torch.zeros(bs * nv, 3).to(device)
+
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] # expanded faces
+ vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()]
+
+ faces = faces.view(-1, 3)
+ vertices_faces = vertices_faces.view(-1, 3, 3)
+
+ normals.index_add_(0, faces[:, 1].long(),
+ torch.cross(vertices_faces[:, 2] - vertices_faces[:, 1], vertices_faces[:, 0] - vertices_faces[:, 1]))
+ normals.index_add_(0, faces[:, 2].long(),
+ torch.cross(vertices_faces[:, 0] - vertices_faces[:, 2], vertices_faces[:, 1] - vertices_faces[:, 2]))
+ normals.index_add_(0, faces[:, 0].long(),
+ torch.cross(vertices_faces[:, 1] - vertices_faces[:, 0], vertices_faces[:, 2] - vertices_faces[:, 0]))
+
+ normals = F.normalize(normals, eps=1e-6, dim=1)
+ normals = normals.reshape((bs, nv, 3))
+ # pytorch only supports long and byte tensors for indexing
+ return normals
+
+
+def tensor_vis_landmarks(images, landmarks, gt_landmarks=None, color='g', isScale=True):
+ # visualize landmarks
+ vis_landmarks = []
+ images = images.cpu().numpy()
+ predicted_landmarks = landmarks.detach().cpu().numpy()
+ if gt_landmarks is not None:
+ gt_landmarks_np = gt_landmarks.detach().cpu().numpy()
+ for i in range(images.shape[0]):
+ image = images[i]
+ image = image.transpose(1, 2, 0)[:, :, [2, 1, 0]].copy();
+ image = (image * 255)
+ if isScale:
+ predicted_landmark = predicted_landmarks[i] * image.shape[0] / 2 + image.shape[0] / 2
+ else:
+ predicted_landmark = predicted_landmarks[i]
+
+ if predicted_landmark.shape[0] == 68:
+ image_landmarks = plot_kpts(image, predicted_landmark, color)
+ if gt_landmarks is not None:
+ image_landmarks = plot_verts(image_landmarks,
+ gt_landmarks_np[i] * image.shape[0] / 2 + image.shape[0] / 2, 'r')
+ else:
+ image_landmarks = plot_verts(image, predicted_landmark, color)
+ if gt_landmarks is not None:
+ image_landmarks = plot_verts(image_landmarks,
+ gt_landmarks_np[i] * image.shape[0] / 2 + image.shape[0] / 2, 'r')
+
+ vis_landmarks.append(image_landmarks)
+
+ vis_landmarks = np.stack(vis_landmarks)
+ vis_landmarks = torch.from_numpy(
+ vis_landmarks[:, :, :, [2, 1, 0]].transpose(0, 3, 1, 2)) / 255. # , dtype=torch.float32)
+ return vis_landmarks
+
+
+end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype = np.int32) - 1
+def plot_kpts(image, kpts, color = 'r'):
+ ''' Draw 68 key points
+ Args:
+ image: the input image
+ kpt: (68, 3).
+ '''
+ if color == 'r':
+ c = (255, 0, 0)
+ elif color == 'g':
+ c = (0, 255, 0)
+ elif color == 'b':
+ c = (255, 0, 0)
+ image = image.copy()
+ kpts = kpts.copy()
+
+ for i in range(kpts.shape[0]):
+ st = kpts[i, :2]
+ if kpts.shape[1]==4:
+ if kpts[i, 3] > 0.5:
+ c = (0, 255, 0)
+ else:
+ c = (0, 0, 255)
+ image = cv2.circle(image,(st[0], st[1]), 1, c, 2)
+ if i in end_list:
+ continue
+ ed = kpts[i + 1, :2]
+ image = cv2.line(image, (st[0], st[1]), (ed[0], ed[1]), (255, 255, 255), 1)
+
+ return image
+
+
+def save_obj(filename, vertices, faces, textures=None, uvcoords=None, uvfaces=None, texture_type='surface'):
+ assert vertices.ndimension() == 2
+ assert faces.ndimension() == 2
+ assert texture_type in ['surface', 'vertex']
+ # assert texture_res >= 2
+
+ if textures is not None and texture_type == 'surface':
+ textures =textures.detach().cpu().numpy().transpose(1,2,0)
+ filename_mtl = filename[:-4] + '.mtl'
+ filename_texture = filename[:-4] + '.png'
+ material_name = 'material_1'
+ # texture_image, vertices_textures = create_texture_image(textures, texture_res)
+ texture_image = textures
+ texture_image = texture_image.clip(0, 1)
+ texture_image = (texture_image * 255).astype('uint8')
+ imsave(filename_texture, texture_image)
+
+ faces = faces.detach().cpu().numpy()
+
+ with open(filename, 'w') as f:
+ f.write('# %s\n' % os.path.basename(filename))
+ f.write('#\n')
+ f.write('\n')
+
+ if textures is not None and texture_type != "vertex":
+ f.write('mtllib %s\n\n' % os.path.basename(filename_mtl))
+
+ if textures is not None and texture_type == 'vertex':
+ for vertex, color in zip(vertices, textures):
+ f.write('v %.8f %.8f %.8f %.8f %.8f %.8f\n' % (vertex[0], vertex[1], vertex[2],
+ color[0], color[1], color[2]))
+ f.write('\n')
+ else:
+ for vertex in vertices:
+ f.write('v %.8f %.8f %.8f\n' % (vertex[0], vertex[1], vertex[2]))
+ f.write('\n')
+
+ if textures is not None and texture_type == 'surface':
+ for vertex in uvcoords.reshape((-1, 2)):
+ f.write('vt %.8f %.8f\n' % (vertex[0], vertex[1]))
+ f.write('\n')
+
+ f.write('usemtl %s\n' % material_name)
+ for i, face in enumerate(faces):
+ f.write('f %d/%d %d/%d %d/%d\n' % (
+ face[0] + 1, uvfaces[i,0]+1, face[1] + 1, uvfaces[i,1]+1, face[2] + 1, uvfaces[i,2]+1))
+ f.write('\n')
+ else:
+ for face in faces:
+ f.write('f %d %d %d\n' % (face[0] + 1, face[1] + 1, face[2] + 1))
+
+ if textures is not None and texture_type == 'surface':
+ with open(filename_mtl, 'w') as f:
+ f.write('newmtl %s\n' % material_name)
+ f.write('map_Kd %s\n' % os.path.basename(filename_texture))
+
+
+def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ return torch.sum(x*y, -1, keepdim=True)
+
+def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
+ return 2*dot(x, n)*n - x
+
+def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
+ return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
+
+def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
+ return x / length(x, eps)
+
+def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
+ return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w)
+
+def compute_face_normals(verts, faces):
+ i0 = faces[..., 0].long()
+ i1 = faces[..., 1].long()
+ i2 = faces[..., 2].long()
+
+ v0 = verts[..., i0, :]
+ v1 = verts[..., i1, :]
+ v2 = verts[..., i2, :]
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+ return face_normals
+
+def compute_face_orientation(verts, faces, return_scale=False):
+ i0 = faces[..., 0].long()
+ i1 = faces[..., 1].long()
+ i2 = faces[..., 2].long()
+
+ v0 = verts[..., i0, :]
+ v1 = verts[..., i1, :]
+ v2 = verts[..., i2, :]
+
+ a0 = safe_normalize(v1 - v0)
+ a1 = safe_normalize(torch.cross(a0, v2 - v0, dim=-1))
+ a2 = -safe_normalize(torch.cross(a1, a0, dim=-1)) # will have artifacts without negation
+
+ orientation = torch.cat([a0[..., None], a1[..., None], a2[..., None]], dim=-1)
+
+ if return_scale:
+ s0 = length(v1 - v0)
+ s1 = dot(a2, (v2 - v0)).abs()
+ scale = (s0 + s1) / 2
+ else:
+ scale = None
+ return orientation, scale
+
+def compute_vertex_normals(verts, faces):
+ i0 = faces[..., 0].long()
+ i1 = faces[..., 1].long()
+ i2 = faces[..., 2].long()
+
+ v0 = verts[..., i0, :]
+ v1 = verts[..., i1, :]
+ v2 = verts[..., i2, :]
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+ v_normals = torch.zeros_like(verts)
+ N = verts.shape[0]
+ v_normals.scatter_add_(1, i0[..., None].repeat(N, 1, 3), face_normals)
+ v_normals.scatter_add_(1, i1[..., None].repeat(N, 1, 3), face_normals)
+ v_normals.scatter_add_(1, i2[..., None].repeat(N, 1, 3), face_normals)
+
+ v_normals = torch.where(dot(v_normals, v_normals) > 1e-20, v_normals, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))
+ v_normals = safe_normalize(v_normals)
+ if torch.is_anomaly_enabled():
+ assert torch.all(torch.isfinite(v_normals))
+ return v_normals
\ No newline at end of file
diff --git a/lam/models/rendering/utils/point_utils.py b/lam/models/rendering/utils/point_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f701308d028056a4ef6ce145bd6f4b8983c02e6a
--- /dev/null
+++ b/lam/models/rendering/utils/point_utils.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import os, cv2
+import matplotlib.pyplot as plt
+import math
+
+def depths_to_points(view, depthmap):
+ c2w = (view.world_view_transform.T).inverse()
+ if hasattr(view, "image_width"):
+ W, H = view.image_width, view.image_height
+ else:
+ W, H = view.width, view.height
+ ndc2pix = torch.tensor([
+ [W / 2, 0, 0, (W) / 2],
+ [0, H / 2, 0, (H) / 2],
+ [0, 0, 0, 1]]).float().cuda().T
+ projection_matrix = c2w.T @ view.full_proj_transform
+ intrins = (projection_matrix @ ndc2pix)[:3,:3].T
+
+ grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy')
+ points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3)
+ rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T
+ rays_o = c2w[:3,3]
+ points = depthmap.reshape(-1, 1) * rays_d + rays_o
+ return points
+
+def depth_to_normal(view, depth):
+ """
+ view: view camera
+ depth: depthmap
+ """
+ points = depths_to_points(view, depth).reshape(*depth.shape[1:], 3)
+ output = torch.zeros_like(points)
+ dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0)
+ dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1)
+ normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
+ output[1:-1, 1:-1, :] = normal_map
+ return output
\ No newline at end of file
diff --git a/lam/models/rendering/utils/renderer.py b/lam/models/rendering/utils/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a978494fbbd980920dd8236f425bd52656f6baa
--- /dev/null
+++ b/lam/models/rendering/utils/renderer.py
@@ -0,0 +1,302 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Zexin He in 2023-2024.
+# The modifications are subject to the same license as the original.
+
+
+"""
+The renderer is a module that takes in rays, decides where to sample along each
+ray, and computes pixel colors using the volume rendering equation.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import math_utils
+
+def generate_planes():
+ """
+ Defines planes by the three vectors that form the "axes" of the
+ plane. Should work with arbitrary number of planes and planes of
+ arbitrary orientation.
+
+ Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
+ """
+ return torch.tensor([[[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]],
+ [[1, 0, 0],
+ [0, 0, 1],
+ [0, 1, 0]],
+ [[0, 0, 1],
+ [0, 1, 0],
+ [1, 0, 0]]], dtype=torch.float32)
+
+def project_onto_planes(planes, coordinates):
+ """
+ Does a projection of a 3D point onto a batch of 2D planes,
+ returning 2D plane coordinates.
+
+ Takes plane axes of shape n_planes, 3, 3
+ # Takes coordinates of shape N, M, 3
+ # returns projections of shape N*n_planes, M, 2
+ """
+ N, M, C = coordinates.shape
+ n_planes, _, _ = planes.shape
+ coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
+ inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
+ projections = torch.bmm(coordinates, inv_planes)
+ return projections[..., :2]
+
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
+ assert padding_mode == 'zeros'
+ N, n_planes, C, H, W = plane_features.shape
+ _, M, _ = coordinates.shape
+ plane_features = plane_features.view(N*n_planes, C, H, W)
+
+ coordinates = (2/box_warp) * coordinates # add specific box bounds
+
+ projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
+ output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+ return output_features
+
+def sample_from_3dgrid(grid, coordinates):
+ """
+ Expects coordinates in shape (batch_size, num_points_per_batch, 3)
+ Expects grid in shape (1, channels, H, W, D)
+ (Also works if grid has batch size)
+ Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
+ """
+ batch_size, n_coords, n_dims = coordinates.shape
+ sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
+ coordinates.reshape(batch_size, 1, 1, -1, n_dims),
+ mode='bilinear', padding_mode='zeros', align_corners=False)
+ N, C, H, W, D = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
+ return sampled_features
+
+class ImportanceRenderer(torch.nn.Module):
+ """
+ Modified original version to filter out-of-box samples as TensoRF does.
+
+ Reference:
+ TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
+ """
+ def __init__(self):
+ super().__init__()
+ self.activation_factory = self._build_activation_factory()
+ self.ray_marcher = MipRayMarcher2(self.activation_factory)
+ self.plane_axes = generate_planes()
+
+ def _build_activation_factory(self):
+ def activation_factory(options: dict):
+ if options['clamp_mode'] == 'softplus':
+ return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better
+ else:
+ assert False, "Renderer only supports `clamp_mode`=`softplus`!"
+ return activation_factory
+
+ def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
+ planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
+ """
+ Additional filtering is applied to filter out-of-box samples.
+ Modifications made by Zexin He.
+ """
+
+ # context related variables
+ batch_size, num_rays, samples_per_ray, _ = depths.shape
+ device = depths.device
+
+ # define sample points with depths
+ sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
+ sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
+
+ # filter out-of-box samples
+ mask_inbox = \
+ (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
+ (sample_coordinates <= rendering_options['sampler_bbox_max'])
+ mask_inbox = mask_inbox.all(-1)
+
+ # forward model according to all samples
+ _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
+
+ # set out-of-box samples to zeros(rgb) & -inf(sigma)
+ SAFE_GUARD = 8
+ DATA_TYPE = _out['sigma'].dtype
+ colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
+ densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
+ colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]
+
+ # reshape back
+ colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
+ densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])
+
+ return colors_pass, densities_pass
+
+ def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, bg_colors=None):
+ # self.plane_axes = self.plane_axes.to(ray_origins.device)
+
+ if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
+ ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
+ is_ray_valid = ray_end > ray_start
+ if torch.any(is_ray_valid).item():
+ ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
+ ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
+ depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+ else:
+ # Create stratified depth samples
+ depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+
+ # Coarse Pass
+ colors_coarse, densities_coarse = self._forward_pass(
+ depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
+ planes=planes, decoder=decoder, rendering_options=rendering_options)
+
+ # Fine Pass
+ N_importance = rendering_options['depth_resolution_importance']
+ if N_importance > 0:
+ _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, bg_colors=bg_colors)
+
+ depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
+
+ colors_fine, densities_fine = self._forward_pass(
+ depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
+ planes=planes, decoder=decoder, rendering_options=rendering_options)
+
+ all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
+ depths_fine, colors_fine, densities_fine)
+
+ # Aggregate
+ rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, bg_colors=bg_colors)
+ else:
+ rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, bg_colors=bg_colors)
+
+ return rgb_final, depth_final, weights.sum(2)
+
+ def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
+ plane_axes = self.plane_axes.to(planes.device)
+ sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
+
+ out = decoder(sampled_features, sample_directions)
+ if options.get('density_noise', 0) > 0:
+ out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
+ return out
+
+ def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
+ out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
+ out['sigma'] = self.activation_factory(options)(out['sigma'])
+ return out
+
+ def sort_samples(self, all_depths, all_colors, all_densities):
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+ return all_depths, all_colors, all_densities
+
+ def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
+ all_depths = torch.cat([depths1, depths2], dim = -2)
+ all_colors = torch.cat([colors1, colors2], dim = -2)
+ all_densities = torch.cat([densities1, densities2], dim = -2)
+
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+
+ return all_depths, all_colors, all_densities
+
+ def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
+ """
+ Return depths of approximately uniformly spaced samples along rays.
+ """
+ N, M, _ = ray_origins.shape
+ if disparity_space_sampling:
+ depths_coarse = torch.linspace(0,
+ 1,
+ depth_resolution,
+ device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = 1/(depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+ depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
+ else:
+ if type(ray_start) == torch.Tensor:
+ depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
+ depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
+ else:
+ depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+
+ return depths_coarse
+
+ def sample_importance(self, z_vals, weights, N_importance):
+ """
+ Return depths of importance sampled points along rays. See NeRF importance sampling for more.
+ """
+ with torch.no_grad():
+ batch_size, num_rays, samples_per_ray, _ = z_vals.shape
+
+ z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
+ weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
+
+ # smooth weights
+ weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
+ weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
+ weights = weights + 0.01
+
+ z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
+ importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
+ N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
+ return importance_z_vals
+
+ def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
+ """
+ Sample @N_importance samples from @bins with distribution defined by @weights.
+ Inputs:
+ bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
+ weights: (N_rays, N_samples_)
+ N_importance: the number of samples to draw from the distribution
+ det: deterministic or not
+ eps: a small number to prevent division by zero
+ Outputs:
+ samples: the sampled samples
+ """
+ N_rays, N_samples_ = weights.shape
+ weights = weights + eps # prevent division by zero (don't do inplace op!)
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
+ cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
+ cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
+ # padded to 0~1 inclusive
+
+ if det:
+ u = torch.linspace(0, 1, N_importance, device=bins.device)
+ u = u.expand(N_rays, N_importance)
+ else:
+ u = torch.rand(N_rays, N_importance, device=bins.device)
+ u = u.contiguous()
+
+ inds = torch.searchsorted(cdf, u, right=True)
+ below = torch.clamp_min(inds-1, 0)
+ above = torch.clamp_max(inds, N_samples_)
+
+ inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
+ cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
+ bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
+
+ denom = cdf_g[...,1]-cdf_g[...,0]
+ denom[denom= 0
+ coeff = (deg + 1) ** 2
+ assert sh.shape[-1] >= coeff
+
+ result = C0 * sh[..., 0]
+ if deg > 0:
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+ result = (result -
+ C1 * y * sh[..., 1] +
+ C1 * z * sh[..., 2] -
+ C1 * x * sh[..., 3])
+
+ if deg > 1:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result = (result +
+ C2[0] * xy * sh[..., 4] +
+ C2[1] * yz * sh[..., 5] +
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+ C2[3] * xz * sh[..., 7] +
+ C2[4] * (xx - yy) * sh[..., 8])
+
+ if deg > 2:
+ result = (result +
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+ C3[1] * xy * z * sh[..., 10] +
+ C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+ C3[5] * z * (xx - yy) * sh[..., 14] +
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+ if deg > 3:
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+ return result
+
+def RGB2SH(rgb):
+ return (rgb - 0.5) / C0
+
+def SH2RGB(sh):
+ return sh * C0 + 0.5
\ No newline at end of file
diff --git a/lam/models/rendering/utils/typing.py b/lam/models/rendering/utils/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee9f967c21f94db1ad939d7dead156d86748752
--- /dev/null
+++ b/lam/models/rendering/utils/typing.py
@@ -0,0 +1,40 @@
+"""
+This module contains type annotations for the project, using
+1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
+2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
+
+Two types of typing checking can be used:
+1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
+2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
+"""
+
+# Basic types
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ NamedTuple,
+ NewType,
+ Optional,
+ Sized,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+# Tensor dtype
+# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
+from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
+
+# Config type
+from omegaconf import DictConfig
+
+# PyTorch Tensor type
+from torch import Tensor
+
+# Runtime type checking decorator
+from typeguard import typechecked as typechecker
diff --git a/lam/models/rendering/utils/utils.py b/lam/models/rendering/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f9d29814d447ef553001c9ad5bfd5210d3d9e85
--- /dev/null
+++ b/lam/models/rendering/utils/utils.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+from lam.models.rendering.utils.typing import *
+
+def get_activation(name):
+ if name is None:
+ return lambda x: x
+ name = name.lower()
+ if name == "none":
+ return lambda x: x
+ elif name == "lin2srgb":
+ return lambda x: torch.where(
+ x > 0.0031308,
+ torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
+ 12.92 * x,
+ ).clamp(0.0, 1.0)
+ elif name == "exp":
+ return lambda x: torch.exp(x)
+ elif name == "shifted_exp":
+ return lambda x: torch.exp(x - 1.0)
+ elif name == "trunc_exp":
+ return trunc_exp
+ elif name == "shifted_trunc_exp":
+ return lambda x: trunc_exp(x - 1.0)
+ elif name == "sigmoid":
+ return lambda x: torch.sigmoid(x)
+ elif name == "tanh":
+ return lambda x: torch.tanh(x)
+ elif name == "shifted_softplus":
+ return lambda x: F.softplus(x - 1.0)
+ elif name == "scale_-11_01":
+ return lambda x: x * 0.5 + 0.5
+ else:
+ try:
+ return getattr(F, name)
+ except AttributeError:
+ raise ValueError(f"Unknown activation function: {name}")
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ dim_in: int,
+ dim_out: int,
+ n_neurons: int,
+ n_hidden_layers: int,
+ activation: str = "relu",
+ output_activation: Optional[str] = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ layers = [
+ self.make_linear(
+ dim_in, n_neurons, is_first=True, is_last=False, bias=bias
+ ),
+ self.make_activation(activation),
+ ]
+ for i in range(n_hidden_layers - 1):
+ layers += [
+ self.make_linear(
+ n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
+ ),
+ self.make_activation(activation),
+ ]
+ layers += [
+ self.make_linear(
+ n_neurons, dim_out, is_first=False, is_last=True, bias=bias
+ )
+ ]
+ self.layers = nn.Sequential(*layers)
+ self.output_activation = get_activation(output_activation)
+
+ def forward(self, x):
+ x = self.layers(x)
+ x = self.output_activation(x)
+ return x
+
+ def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
+ layer = nn.Linear(dim_in, dim_out, bias=bias)
+ return layer
+
+ def make_activation(self, activation):
+ if activation == "relu":
+ return nn.ReLU(inplace=True)
+ elif activation == "silu":
+ return nn.SiLU(inplace=True)
+ else:
+ raise NotImplementedError
+
+
+class _TruncExp(Function): # pylint: disable=abstract-method
+ # Implementation from torch-ngp:
+ # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, x): # pylint: disable=arguments-differ
+ ctx.save_for_backward(x)
+ return torch.exp(x)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, g): # pylint: disable=arguments-differ
+ x = ctx.saved_tensors[0]
+ return g * torch.exp(torch.clamp(x, max=15))
+
+
+trunc_exp = _TruncExp.apply
\ No newline at end of file
diff --git a/lam/models/rendering/utils/uv_utils.py b/lam/models/rendering/utils/uv_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..92511ecc6664b8e6485a4fcd3bf3d9c55fcf3c8a
--- /dev/null
+++ b/lam/models/rendering/utils/uv_utils.py
@@ -0,0 +1,366 @@
+import torch
+import numpy as np
+import math
+import torch.nn as nn
+
+from pytorch3d.structures import Meshes
+from pytorch3d.io import load_obj
+from pytorch3d.renderer.mesh import rasterize_meshes
+from pytorch3d.ops import mesh_face_areas_normals
+
+#-------------------------------------------------------------------------------#
+
+def gen_tritex(vt: np.ndarray, vi: np.ndarray, vti: np.ndarray, texsize: int):
+ """
+ Copied from MVP
+ Create 3 texture maps containing the vertex indices, texture vertex
+ indices, and barycentric coordinates
+
+ Parameters
+ ----------
+ vt: uv coordinates of texels
+ vi: triangle list mapping into vertex positions
+ vti: triangle list mapping into texel coordinates
+ texsize: Size of the generated maps
+ """
+ # vt = ((vt + 1. ) / 2.)[:, :2]
+ vt = vt[:, :2]
+
+ vt = np.array(vt, dtype=np.float32)
+ vi = np.array(vi, dtype=np.int32)
+ vti = np.array(vti, dtype=np.int32)
+ ntris = vi.shape[0]
+
+ texu, texv = np.meshgrid(
+ (np.arange(texsize) + 0.5) / texsize,
+ (np.arange(texsize) + 0.5) / texsize)
+ texuv = np.stack((texu, texv), axis=-1)
+
+ vt = vt[vti]
+
+ viim = np.zeros((texsize, texsize, 3), dtype=np.int32)
+ vtiim = np.zeros((texsize, texsize, 3), dtype=np.int32)
+ baryim = np.zeros((texsize, texsize, 3), dtype=np.float32)
+
+ for i in list(range(ntris))[::-1]:
+ bbox = (
+ max(0, int(min(vt[i, 0, 0], min(vt[i, 1, 0], vt[i, 2, 0])) * texsize) - 1),
+ min(texsize, int(max(vt[i, 0, 0], max(vt[i, 1, 0], vt[i, 2, 0])) * texsize) + 2),
+ max(0, int(min(vt[i, 0, 1], min(vt[i, 1, 1], vt[i, 2, 1])) * texsize) - 1),
+ min(texsize, int(max(vt[i, 0, 1], max(vt[i, 1, 1], vt[i, 2, 1])) * texsize) + 2))
+ v0 = vt[None, None, i, 1, :] - vt[None, None, i, 0, :]
+ v1 = vt[None, None, i, 2, :] - vt[None, None, i, 0, :]
+ v2 = texuv[bbox[2]:bbox[3], bbox[0]:bbox[1], :] - vt[None, None, i, 0, :]
+ d00 = np.sum(v0 * v0, axis=-1)
+ d01 = np.sum(v0 * v1, axis=-1)
+ d11 = np.sum(v1 * v1, axis=-1)
+ d20 = np.sum(v2 * v0, axis=-1)
+ d21 = np.sum(v2 * v1, axis=-1)
+ denom = d00 * d11 - d01 * d01
+
+ if denom != 0.:
+ baryv = (d11 * d20 - d01 * d21) / denom
+ baryw = (d00 * d21 - d01 * d20) / denom
+ baryu = 1. - baryv - baryw
+
+ baryim[bbox[2]:bbox[3], bbox[0]:bbox[1], :] = np.where(
+ ((baryu >= 0.) & (baryv >= 0.) & (baryw >= 0.))[:, :, None],
+ np.stack((baryu, baryv, baryw), axis=-1),
+ baryim[bbox[2]:bbox[3], bbox[0]:bbox[1], :])
+ viim[bbox[2]:bbox[3], bbox[0]:bbox[1], :] = np.where(
+ ((baryu >= 0.) & (baryv >= 0.) & (baryw >= 0.))[:, :, None],
+ np.stack((vi[i, 0], vi[i, 1], vi[i, 2]), axis=-1),
+ viim[bbox[2]:bbox[3], bbox[0]:bbox[1], :])
+ vtiim[bbox[2]:bbox[3], bbox[0]:bbox[1], :] = np.where(
+ ((baryu >= 0.) & (baryv >= 0.) & (baryw >= 0.))[:, :, None],
+ np.stack((vti[i, 0], vti[i, 1], vti[i, 2]), axis=-1),
+ vtiim[bbox[2]:bbox[3], bbox[0]:bbox[1], :])
+
+ return torch.LongTensor(viim), torch.Tensor(vtiim), torch.Tensor(baryim)
+
+
+# modified from https://github.com/facebookresearch/pytorch3d
+class Pytorch3dRasterizer(nn.Module):
+ def __init__(self, image_size=224):
+ """
+ use fixed raster_settings for rendering faces
+ """
+ super().__init__()
+ raster_settings = {
+ 'image_size': image_size,
+ 'blur_radius': 0.0,
+ 'faces_per_pixel': 1,
+ 'bin_size': None,
+ 'max_faces_per_bin': None,
+ 'perspective_correct': False,
+ 'cull_backfaces': True
+ }
+ # raster_settings = dict2obj(raster_settings)
+ self.raster_settings = raster_settings
+
+ def forward(self, vertices, faces, h=None, w=None):
+ fixed_vertices = vertices.clone()
+ fixed_vertices[...,:2] = -fixed_vertices[...,:2]
+ raster_settings = self.raster_settings
+ if h is None and w is None:
+ image_size = raster_settings['image_size']
+ else:
+ image_size = [h, w]
+ if h>w:
+ fixed_vertices[..., 1] = fixed_vertices[..., 1]*h/w
+ else:
+ fixed_vertices[..., 0] = fixed_vertices[..., 0]*w/h
+
+ meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
+ meshes_screen,
+ image_size=image_size,
+ blur_radius=raster_settings['blur_radius'],
+ faces_per_pixel=raster_settings['faces_per_pixel'],
+ bin_size=raster_settings['bin_size'],
+ max_faces_per_bin=raster_settings['max_faces_per_bin'],
+ perspective_correct=raster_settings['perspective_correct'],
+ cull_backfaces=raster_settings['cull_backfaces']
+ )
+
+ return pix_to_face, bary_coords
+
+#-------------------------------------------------------------------------------#
+
+# borrowed from https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/vertices_to_faces.py
+def face_vertices(vertices, faces):
+ """
+ Indexing the coordinates of the three vertices on each face.
+
+ Args:
+ vertices: [bs, V, 3]
+ faces: [bs, F, 3]
+
+ Return:
+ face_to_vertices: [bs, F, 3, 3]
+ """
+ assert (vertices.ndimension() == 3)
+ assert (faces.ndimension() == 3)
+ # assert (vertices.shape[0] == faces.shape[0])
+ assert (vertices.shape[2] == 3)
+ assert (faces.shape[2] == 3)
+
+ bs, nv = vertices.shape[:2]
+ bs, nf = faces.shape[:2]
+ device = vertices.device
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+ vertices = vertices.reshape((bs * nv, 3))
+ # pytorch only supports long and byte tensors for indexing
+ return vertices[faces.long()]
+
+def uniform_sampling_barycoords(
+ num_points: int,
+ tex_coord: torch.Tensor,
+ uv_faces: torch.Tensor,
+ d_size: float=1.0,
+ strict: bool=False,
+ use_mask: bool=True,
+ ):
+ """
+ Uniformly sampling barycentric coordinates using the rasterizer.
+
+ Args:
+ num_points: int sampling points number
+ tex_coord: [5150, 2] UV coords for each vert
+ uv_faces: [F,3] UV faces to UV coords index
+ d_size: const to control sampling points number
+ use_mask: use mask to mask valid points
+ Returns:
+ face_index [num_points] save which face each bary_coords belongs to
+ bary_coords [num_points, 3]
+ """
+
+ uv_size = int(math.sqrt(num_points) * d_size)
+ uv_rasterizer = Pytorch3dRasterizer(uv_size)
+
+ tex_coord = tex_coord[None, ...]
+ uv_faces = uv_faces[None, ...]
+
+ tex_coord_ = torch.cat([tex_coord, tex_coord[:,:,0:1]*0.+1.], -1)
+ tex_coord_ = tex_coord_ * 2 - 1
+ tex_coord_[...,1] = - tex_coord_[...,1]
+
+ pix_to_face, bary_coords = uv_rasterizer(tex_coord_.expand(1, -1, -1), uv_faces.expand(1, -1, -1))
+ mask = (pix_to_face == -1)
+
+ if use_mask:
+ face_index = pix_to_face[~mask]
+ bary_coords = bary_coords[~mask]
+ else:
+ return pix_to_face, bary_coords
+
+ cur_n = face_index.shape[0]
+
+ # fix sampling number to num_points
+ if strict:
+ if cur_n < num_points:
+ pad_size = num_points - cur_n
+ new_face_index = face_index[torch.randint(0, cur_n, (pad_size,))]
+ new_bary_coords = torch.rand((pad_size, 3), device=bary_coords.device)
+ new_bary_coords = new_bary_coords / new_bary_coords.sum(dim=-1, keepdim=True)
+ face_index = torch.cat([face_index, new_face_index], dim=0)
+ bary_coords = torch.cat([bary_coords, new_bary_coords], dim=0)
+ elif cur_n > num_points:
+ face_index = face_index[:num_points]
+ bary_coords = bary_coords[:num_points]
+
+ return face_index, bary_coords
+
+def random_sampling_barycoords(
+ num_points: int,
+ vertices: torch.Tensor,
+ faces: torch.Tensor
+ ):
+ """
+ Randomly sampling barycentric coordinates using the rasterizer.
+
+ Args:
+ num_points: int sampling points number
+ vertices: [V, 3]
+ faces: [F,3]
+ Returns:
+ face_index [num_points] save which face each bary_coords belongs to
+ bary_coords [num_points, 3]
+ """
+
+ areas, _ = mesh_face_areas_normals(vertices.squeeze(0), faces)
+
+ g1 = torch.Generator(device=vertices.device)
+ g1.manual_seed(0)
+
+ face_index = areas.multinomial(
+ num_points, replacement=True, generator=g1
+ ) # (N, num_samples)
+
+ uvw = torch.rand((face_index.shape[0], 3), device=vertices.device)
+ bary_coords = uvw / uvw.sum(dim=-1, keepdim=True)
+
+ return face_index, bary_coords
+
+def reweight_verts_by_barycoords(
+ verts: torch.Tensor,
+ faces: torch.Tensor,
+ face_index: torch.Tensor,
+ bary_coords: torch.Tensor,
+ ):
+ """
+ Reweights the vertices based on the barycentric coordinates for each face.
+
+ Args:
+ verts: [bs, V, 3].
+ faces: [F, 3]
+ face_index: [N].
+ bary_coords: [N, 3].
+
+ Returns:
+ Reweighted vertex positions of shape [bs, N, 3].
+ """
+
+ # index attributes by face
+ B = verts.shape[0]
+
+ face_verts = face_vertices(verts, faces.expand(B, -1, -1)) # [1, F, 3, 3]
+ # gather idnex for every splat
+ N = face_index.shape[0]
+ face_index_3 = face_index.view(1, N, 1, 1).expand(B, N, 3, 3)
+ position_vals = face_verts.gather(1, face_index_3)
+ # reweight
+ position_vals = (bary_coords[..., None] * position_vals).sum(dim = -2)
+
+ return position_vals
+
+def reweight_uvcoords_by_barycoords(
+ uvcoords: torch.Tensor,
+ uvfaces: torch.Tensor,
+ face_index: torch.Tensor,
+ bary_coords: torch.Tensor,
+ ):
+ """
+ Reweights the UV coordinates based on the barycentric coordinates for each face.
+
+ Args:
+ uvcoords: [bs, V', 2].
+ uvfaces: [F, 3].
+ face_index: [N].
+ bary_coords: [N, 3].
+
+ Returns:
+ Reweighted UV coordinates, shape [bs, N, 2].
+ """
+
+ # homogeneous coordinates
+ num_v = uvcoords.shape[0]
+ uvcoords = torch.cat([uvcoords, torch.ones((num_v, 1)).to(uvcoords.device)], dim=1)
+ # index attributes by face
+ uvcoords = uvcoords[None, ...]
+ face_verts = face_vertices(uvcoords, uvfaces.expand(1, -1, -1)) # [1, F, 3, 3]
+ # gather idnex for every splat
+ N = face_index.shape[0]
+ face_index_3 = face_index.view(1, N, 1, 1).expand(1, N, 3, 3)
+ position_vals = face_verts.gather(1, face_index_3)
+ # reweight
+ position_vals = (bary_coords[..., None] * position_vals).sum(dim = -2)
+
+ return position_vals
+
+# modified from https://github.com/computational-imaging/GSM/blob/main/main/gsm/deformer/util.py
+def get_shell_verts_from_base(
+ template_verts: torch.Tensor,
+ template_faces: torch.Tensor,
+ offset_len: float,
+ num_shells: int,
+ deflat = False,
+ ):
+ """
+ Generates shell vertices by offsetting the original mesh's vertices along their normals.
+
+ Args:
+ template_verts: [bs, V, 3].
+ template_faces: [F, 3].
+ offset_len: Positive number specifying the offset length for generating shells.
+ num_shells: The number of shells to generate.
+ deflat: If True, performs a deflation process. Defaults to False.
+
+ Returns:
+ shell verts: [bs, num_shells, n, 3]
+ """
+ out_offset_len = offset_len
+
+ if deflat:
+ in_offset_len = offset_len
+
+ batch_size = template_verts.shape[0]
+ mesh = Meshes(
+ verts=template_verts, faces=template_faces[None].repeat(batch_size, 1, 1)
+ )
+ # bs, n, 3
+ vertex_normal = mesh.verts_normals_padded()
+ # only for inflating
+
+ if deflat:
+ n_inflated_shells = num_shells//2 + 1
+ else:
+ n_inflated_shells = num_shells
+
+ linscale = torch.linspace(
+ out_offset_len,
+ 0,
+ n_inflated_shells,
+ device=template_verts.device,
+ dtype=template_verts.dtype,
+ )
+ offset = linscale.reshape(1,n_inflated_shells, 1, 1) * vertex_normal[:, None]
+
+ if deflat:
+ linscale = torch.linspace(0, -in_offset_len, num_shells - n_inflated_shells + 1, device=template_verts.device, dtype=template_verts.dtype)[1:]
+ offset_in = linscale.reshape(1, -1, 1, 1) * vertex_normal[:, None]
+ offset = torch.cat([offset, offset_in], dim=1)
+
+ verts = template_verts[:, None] + offset
+ assert verts.isfinite().all()
+ return verts
\ No newline at end of file
diff --git a/lam/models/rendering/utils/vis_utils.py b/lam/models/rendering/utils/vis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab20328be4342ee916fd64486534101e3b33684
--- /dev/null
+++ b/lam/models/rendering/utils/vis_utils.py
@@ -0,0 +1,377 @@
+import os
+import cv2
+import numpy as np
+from mpl_toolkits.mplot3d import Axes3D
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import os
+import sys
+os.environ["PYOPENGL_PLATFORM"] = "egl"
+from pytorch3d.structures import Meshes, Pointclouds
+from pytorch3d.renderer import (
+ PointLights,
+ DirectionalLights,
+ PerspectiveCameras,
+ Materials,
+ SoftPhongShader,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRendererWithFragments,
+ MeshRasterizer,
+ TexturesVertex,
+ PointsRasterizationSettings,
+ PointsRenderer,
+ PointsRasterizer,
+ AlphaCompositor
+)
+import torch
+import torch.nn as nn
+
+def vis_keypoints_with_skeleton(img, kps, kps_lines, kp_thresh=0.4, alpha=1):
+ # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
+ cmap = plt.get_cmap('rainbow')
+ colors = [cmap(i) for i in np.linspace(0, 1, len(kps_lines) + 2)]
+ colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]
+
+ # Perform the drawing on a copy of the image, to allow for blending.
+ kp_mask = np.copy(img)
+
+ # Draw the keypoints.
+ for l in range(len(kps_lines)):
+ i1 = kps_lines[l][0]
+ i2 = kps_lines[l][1]
+ p1 = kps[0, i1].astype(np.int32), kps[1, i1].astype(np.int32)
+ p2 = kps[0, i2].astype(np.int32), kps[1, i2].astype(np.int32)
+ if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
+ cv2.line(
+ kp_mask, p1, p2,
+ color=colors[l], thickness=2, lineType=cv2.LINE_AA)
+ if kps[2, i1] > kp_thresh:
+ cv2.circle(
+ kp_mask, p1,
+ radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
+ if kps[2, i2] > kp_thresh:
+ cv2.circle(
+ kp_mask, p2,
+ radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
+
+ # Blend the keypoints.
+ return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
+
+def vis_keypoints(img, kps, alpha=1):
+ # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
+ cmap = plt.get_cmap('rainbow')
+ colors = [cmap(i) for i in np.linspace(0, 1, len(kps) + 2)]
+ colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]
+
+ # Perform the drawing on a copy of the image, to allow for blending.
+ kp_mask = np.copy(img)
+
+ # Draw the keypoints.
+ for i in range(len(kps)):
+ p = kps[i][0].astype(np.int32), kps[i][1].astype(np.int32)
+ cv2.circle(kp_mask, p, radius=3, color=colors[i], thickness=-1, lineType=cv2.LINE_AA)
+
+ # Blend the keypoints.
+ return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
+
+
+def render_mesh(mesh, face, cam_param, bkg, blend_ratio=1.0, return_bg_mask=False, R=None, T=None, return_fragments=False):
+ mesh = mesh.cuda()[None,:,:]
+ face = torch.LongTensor(face.astype(np.int64)).cuda()[None,:,:]
+ cam_param = {k: v.cuda()[None,:] for k,v in cam_param.items()}
+ render_shape = (bkg.shape[0], bkg.shape[1]) # height, width
+
+ batch_size, vertex_num = mesh.shape[:2]
+ textures = TexturesVertex(verts_features=torch.ones((batch_size,vertex_num,3)).float().cuda())
+ mesh = torch.stack((-mesh[:,:,0], -mesh[:,:,1], mesh[:,:,2]),2) # reverse x- and y-axis following PyTorch3D axis direction
+ mesh = Meshes(mesh, face, textures)
+
+ if R is None:
+ cameras = PerspectiveCameras(focal_length=cam_param['focal'],
+ principal_point=cam_param['princpt'],
+ device='cuda',
+ in_ndc=False,
+ image_size=torch.LongTensor(render_shape).cuda().view(1,2))
+ else:
+ cameras = PerspectiveCameras(focal_length=cam_param['focal'],
+ principal_point=cam_param['princpt'],
+ device='cuda',
+ in_ndc=False,
+ image_size=torch.LongTensor(render_shape).cuda().view(1,2),
+ R=R,
+ T=T)
+
+ raster_settings = RasterizationSettings(image_size=render_shape, blur_radius=0.0, faces_per_pixel=1, bin_size=0)
+ rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings).cuda()
+ lights = PointLights(device='cuda')
+ shader = SoftPhongShader(device='cuda', cameras=cameras, lights=lights)
+ materials = Materials(
+ device='cuda',
+ specular_color=[[0.0, 0.0, 0.0]],
+ shininess=0.0
+ )
+
+ # render
+ with torch.no_grad():
+ renderer = MeshRendererWithFragments(rasterizer=rasterizer, shader=shader)
+ images, fragments = renderer(mesh, materials=materials)
+
+ # background masking
+ is_bkg = (fragments.zbuf <= 0).float().cpu().numpy()[0]
+ render = images[0,:,:,:3].cpu().numpy()
+ fg = render * blend_ratio + bkg/255 * (1 - blend_ratio)
+ render = fg * (1 - is_bkg) * 255 + bkg * is_bkg
+ ret = [render]
+ if return_bg_mask:
+ ret.append(is_bkg)
+ if return_fragments:
+ ret.append(fragments)
+ return tuple(ret)
+
+
+def rasterize_mesh(mesh, face, cam_param, height, width, return_bg_mask=False, R=None, T=None):
+ mesh = mesh.cuda()[None,:,:]
+ face = face.long().cuda()[None,:,:]
+ cam_param = {k: v.cuda()[None,:] for k,v in cam_param.items()}
+ render_shape = (height, width)
+
+ batch_size, vertex_num = mesh.shape[:2]
+ textures = TexturesVertex(verts_features=torch.ones((batch_size,vertex_num,3)).float().cuda())
+ mesh = torch.stack((-mesh[:,:,0], -mesh[:,:,1], mesh[:,:,2]),2) # reverse x- and y-axis following PyTorch3D axis direction
+ mesh = Meshes(mesh, face, textures)
+
+ if R is None:
+ cameras = PerspectiveCameras(focal_length=cam_param['focal'],
+ principal_point=cam_param['princpt'],
+ device='cuda',
+ in_ndc=False,
+ image_size=torch.LongTensor(render_shape).cuda().view(1,2))
+ else:
+ cameras = PerspectiveCameras(focal_length=cam_param['focal'],
+ principal_point=cam_param['princpt'],
+ device='cuda',
+ in_ndc=False,
+ image_size=torch.LongTensor(render_shape).cuda().view(1,2),
+ R=R,
+ T=T)
+
+ raster_settings = RasterizationSettings(image_size=render_shape, blur_radius=0.0, faces_per_pixel=1, bin_size=0)
+ rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings).cuda()
+
+ # render
+ fragments = rasterizer(mesh)
+
+ ret = [fragments]
+
+ if return_bg_mask:
+ # background masking
+ is_bkg = (fragments.zbuf <= 0).float().cpu().numpy()[0]
+ ret.append(is_bkg)
+
+ return tuple(ret)
+
+
+def rasterize_points(points, cam_param, height, width, return_bg_mask=False, R=None, T=None, to_cpu=False, points_per_pixel=5, radius=0.01):
+ points = torch.stack((-points[:, 0], -points[:, 1], points[:, 2]), 1) # reverse x- and y-axis following PyTorch3D axis direction
+ device = points.device
+ if len(points.shape) == 2:
+ points = [points]
+ pointclouds = Pointclouds(points=points)
+ cam_param = {k: v.to(device)[None,:] for k,v in cam_param.items()}
+ render_shape = (height, width) # height, width
+
+ if R is None:
+ cameras = PerspectiveCameras(focal_length=cam_param['focal'],
+ principal_point=cam_param['princpt'],
+ device=device,
+ in_ndc=False,
+ image_size=torch.LongTensor(render_shape).to(device).view(1,2))
+ else:
+ cameras = PerspectiveCameras(focal_length=cam_param['focal'],
+ principal_point=cam_param['princpt'],
+ device=device,
+ in_ndc=False,
+ image_size=torch.LongTensor(render_shape).to(device).view(1,2),
+ R=R,
+ T=T)
+
+ raster_settings = PointsRasterizationSettings(image_size=render_shape, radius=radius, points_per_pixel=points_per_pixel, max_points_per_bin=82000)
+ rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings).to(device)
+
+ # render
+ fragments = rasterizer(pointclouds)
+
+ # background masking
+ ret = [fragments]
+ if return_bg_mask:
+ if to_cpu:
+ is_bkg = (fragments.zbuf <= 0).all(dim=-1, keepdim=True).float().cpu().numpy()[0]
+ else:
+ is_bkg = (fragments.zbuf <= 0).all(dim=-1, keepdim=True).float()[0]
+ ret.append(is_bkg)
+
+ return tuple(ret)
+
+
+def render_points(points, cam_param, bkg, blend_ratio=1.0, return_bg_mask=False, R=None, T=None, return_fragments=False, rgbs=None):
+ points = torch.stack((-points[:, 0], -points[:, 1], points[:, 2]), 1) # reverse x- and y-axis following PyTorch3D axis direction
+ if rgbs is None:
+ rgbs = torch.ones_like(points)
+ if len(points.shape) == 2:
+ points = [points]
+ rgbs = [rgbs]
+ pointclouds = Pointclouds(points=points, features=rgbs).cuda()
+ cam_param = {k: v.cuda()[None,:] for k,v in cam_param.items()}
+ render_shape = (bkg.shape[0], bkg.shape[1]) # height, width
+
+ if R is None:
+ cameras = PerspectiveCameras(focal_length=cam_param['focal'],
+ principal_point=cam_param['princpt'],
+ device='cuda',
+ in_ndc=False,
+ image_size=torch.LongTensor(render_shape).cuda().view(1,2))
+ else:
+ cameras = PerspectiveCameras(focal_length=cam_param['focal'],
+ principal_point=cam_param['princpt'],
+ device='cuda',
+ in_ndc=False,
+ image_size=torch.LongTensor(render_shape).cuda().view(1,2),
+ R=R,
+ T=T)
+
+ raster_settings = PointsRasterizationSettings(image_size=render_shape, radius=0.01, points_per_pixel=5)
+ rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings).cuda()
+
+ # render
+ with torch.no_grad():
+ fragments = rasterizer(pointclouds)
+ renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor(background_color=(0, 0, 0)))
+ images = renderer(pointclouds)
+
+ # background masking
+ is_bkg = (fragments.zbuf <= 0).all(dim=-1, keepdim=True).float().cpu().numpy()[0]
+ render = images[0,:,:,:3].cpu().numpy()
+ fg = render * blend_ratio + bkg/255 * (1 - blend_ratio)
+ render = fg * (1 - is_bkg) * 255 + bkg * is_bkg
+
+ ret = [render]
+ if return_bg_mask:
+ ret.append(is_bkg)
+ if return_fragments:
+ ret.append(fragments)
+ return tuple(ret)
+
+
+class RenderMesh(nn.Module):
+ def __init__(self, image_size, obj_filename=None, faces=None, device='cpu'):
+ super(RenderMesh, self).__init__()
+ self.device = device
+ self.image_size = image_size
+ if obj_filename is not None:
+ verts, faces, aux = load_obj(obj_filename, load_textures=False)
+ self.faces = faces.verts_idx
+ elif faces is not None:
+ import numpy as np
+ self.faces = torch.tensor(faces.astype(np.int32))
+ else:
+ raise NotImplementedError('Must have faces.')
+ self.raster_settings = RasterizationSettings(image_size=image_size, blur_radius=0.0, faces_per_pixel=1)
+ self.lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
+
+ def _build_cameras(self, transform_matrix, focal_length, principal_point=None, intr=None):
+ batch_size = transform_matrix.shape[0]
+ screen_size = torch.tensor(
+ [self.image_size, self.image_size], device=self.device
+ ).float()[None].repeat(batch_size, 1)
+ if principal_point is None:
+ principal_point = torch.zeros(batch_size, 2, device=self.device).float()
+ # print("==="*16, "principle_points:", principal_point)
+ # print("==="*16, "focal_length:", focal_length)
+ if intr is None:
+ cameras_kwargs = {
+ 'principal_point': principal_point, 'focal_length': focal_length,
+ 'image_size': screen_size, 'device': self.device,
+ }
+ else:
+ cameras_kwargs = {
+ 'principal_point': principal_point, 'focal_length': torch.tensor([intr[0, 0], intr[1, 1]]).unsqueeze(0),
+ 'image_size': screen_size, 'device': self.device,
+ }
+ cameras = PerspectiveCameras(**cameras_kwargs, R=transform_matrix[:, :3, :3], T=transform_matrix[:, :3, 3])
+ return cameras
+
+ def forward(
+ self, vertices, cameras=None, transform_matrix=None, focal_length=None, principal_point=None, only_rasterize=False, intr=None,
+ ):
+ if cameras is None:
+ cameras = self._build_cameras(transform_matrix, focal_length, principal_point=principal_point, intr=intr)
+ faces = self.faces[None].repeat(vertices.shape[0], 1, 1)
+ # Initialize each vertex to be white in color.
+ verts_rgb = torch.ones_like(vertices) # (1, V, 3)
+ textures = TexturesVertex(verts_features=verts_rgb.to(self.device))
+ mesh = Meshes(
+ verts=vertices.to(self.device),
+ faces=faces.to(self.device),
+ textures=textures
+ )
+ renderer = MeshRendererWithFragments(
+ rasterizer=MeshRasterizer(cameras=cameras, raster_settings=self.raster_settings),
+ shader=SoftPhongShader(cameras=cameras, lights=self.lights, device=self.device)
+ )
+ render_results, fragments = renderer(mesh)
+ render_results = render_results.permute(0, 3, 1, 2)
+ if only_rasterize:
+ return fragments
+ images = render_results[:, :3]
+ alpha_images = render_results[:, 3:]
+ images[alpha_images.expand(-1, 3, -1, -1)<0.5] = 0.0
+ return images*255, alpha_images
+
+
+class RenderPoints(nn.Module):
+ def __init__(self, image_size, obj_filename=None, device='cpu'):
+ super(RenderPoints, self).__init__()
+ self.device = device
+ self.image_size = image_size
+ if obj_filename is not None:
+ verts = load_obj(obj_filename, load_textures=False)
+ self.raster_settings = PointsRasterizationSettings(image_size=image_size, radius=0.01, points_per_pixel=1)
+ self.lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
+
+ def _build_cameras(self, transform_matrix, focal_length, principal_point=None):
+ batch_size = transform_matrix.shape[0]
+ screen_size = torch.tensor(
+ [self.image_size, self.image_size], device=self.device
+ ).float()[None].repeat(batch_size, 1)
+ if principal_point is None:
+ principal_point = torch.zeros(batch_size, 2, device=self.device).float()
+ # print("==="*16, "principle_points:", principal_point)
+ # print("==="*16, "focal_length:", focal_length)
+ cameras_kwargs = {
+ 'principal_point': principal_point, 'focal_length': focal_length,
+ 'image_size': screen_size, 'device': self.device,
+ }
+ cameras = PerspectiveCameras(**cameras_kwargs, R=transform_matrix[:, :3, :3], T=transform_matrix[:, :3, 3])
+ return cameras
+
+ def forward(
+ self, vertices, cameras=None, transform_matrix=None, focal_length=None, principal_point=None, only_rasterize=False
+ ):
+ if cameras is None:
+ cameras = self._build_cameras(transform_matrix, focal_length, principal_point=principal_point)
+ # Initialize each vertex to be white in color.
+ verts_rgb = torch.ones_like(vertices) # (1, V, 3)
+ pointclouds = Pointclouds(points=vertices, features=verts_rgb).cuda()
+
+ # render
+ rasterizer = PointsRasterizer(cameras=cameras, raster_settings=self.raster_settings).cuda()
+ if only_rasterize:
+ fragments = rasterizer(pointclouds)
+ return fragments
+ renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor(background_color=(0, 0, 0)))
+ render_results = renderer(pointclouds).permute(0, 3, 1, 2)
+ images = render_results[:, :3]
+ alpha_images = render_results[:, 3:]
+
+ return images*255, alpha_images
\ No newline at end of file
diff --git a/lam/models/transformer.py b/lam/models/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..95e9b7f7556f9ed5d4f8fa77d9201770b7558550
--- /dev/null
+++ b/lam/models/transformer.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from functools import partial
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+from typing import Any, Dict, Optional, Tuple, Union
+from diffusers.utils import is_torch_version
+
+logger = get_logger(__name__)
+
+
+class TransformerDecoder(nn.Module):
+
+ """
+ Transformer blocks that process the input and optionally use condition and modulation.
+ """
+
+ def __init__(self, block_type: str,
+ num_layers: int, num_heads: int,
+ inner_dim: int, cond_dim: int = None, mod_dim: int = None,
+ gradient_checkpointing=False,
+ eps: float = 1e-6,
+ use_dual_attention: bool = False,):
+ super().__init__()
+ self.gradient_checkpointing = gradient_checkpointing
+ self.block_type = block_type
+ if block_type == "sd3_cond":
+ # dual_attention_layers = list(range(num_layers//2))
+ dual_attention_layers = []
+ self.layers = nn.ModuleList([
+ self._block_fn(inner_dim, cond_dim, mod_dim)(
+ num_heads=num_heads,
+ eps=eps,
+ context_pre_only=i == num_layers - 1,
+ use_dual_attention=use_dual_attention, # True if i in dual_attention_layers else False,
+ )
+ for i in range(num_layers)
+ ])
+ else:
+ self.layers = nn.ModuleList([
+ self._block_fn(inner_dim, cond_dim, mod_dim)(
+ num_heads=num_heads,
+ eps=eps,
+ )
+ for _ in range(num_layers)
+ ])
+
+
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
+
+ if self.block_type in ["cogvideo_cond", "sd3_cond"]:
+ self.linear_cond_proj = nn.Linear(cond_dim, inner_dim)
+
+ @property
+ def block_type(self):
+ return self._block_type
+
+ @block_type.setter
+ def block_type(self, block_type):
+ assert block_type in ['basic', 'cond', 'mod', 'cond_mod', 'sd3_cond', 'cogvideo_cond'], \
+ f"Unsupported block type: {block_type}"
+ self._block_type = block_type
+
+ def _block_fn(self, inner_dim, cond_dim, mod_dim):
+ assert inner_dim is not None, f"inner_dim must always be specified"
+ if self.block_type == 'basic':
+ assert cond_dim is None and mod_dim is None, \
+ f"Condition and modulation are not supported for BasicBlock"
+ from .block import BasicBlock
+ # logger.debug(f"Using BasicBlock")
+ return partial(BasicBlock, inner_dim=inner_dim)
+ elif self.block_type == 'cond':
+ assert cond_dim is not None, f"Condition dimension must be specified for ConditionBlock"
+ assert mod_dim is None, f"Modulation dimension is not supported for ConditionBlock"
+ from .block import ConditionBlock
+ # logger.debug(f"Using ConditionBlock")
+ return partial(ConditionBlock, inner_dim=inner_dim, cond_dim=cond_dim)
+ elif self.block_type == 'mod':
+ # logger.error(f"modulation without condition is not implemented")
+ raise NotImplementedError(f"modulation without condition is not implemented")
+ elif self.block_type == 'cond_mod':
+ assert cond_dim is not None and mod_dim is not None, \
+ f"Condition and modulation dimensions must be specified for ConditionModulationBlock"
+ from .block import ConditionModulationBlock
+ # logger.debug(f"Using ConditionModulationBlock")
+ return partial(ConditionModulationBlock, inner_dim=inner_dim, cond_dim=cond_dim, mod_dim=mod_dim)
+ elif self.block_type == 'cogvideo_cond':
+ # logger.debug(f"Using CogVideoXBlock")
+ from lam.models.transformer_dit import CogVideoXBlock
+ # assert inner_dim == cond_dim, f"inner_dim:{inner_dim}, cond_dim:{cond_dim}"
+ return partial(CogVideoXBlock, dim=inner_dim, attention_bias=True)
+ elif self.block_type == 'sd3_cond':
+ # logger.debug(f"Using SD3JointTransformerBlock")
+ from lam.models.transformer_dit import SD3JointTransformerBlock
+ return partial(SD3JointTransformerBlock, dim=inner_dim, qk_norm="rms_norm")
+ else:
+ raise ValueError(f"Unsupported block type during runtime: {self.block_type}")
+
+ def assert_runtime_integrity(self, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor):
+ assert x is not None, f"Input tensor must be specified"
+ if self.block_type == 'basic':
+ assert cond is None and mod is None, \
+ f"Condition and modulation are not supported for BasicBlock"
+ elif 'cond' in self.block_type:
+ assert cond is not None and mod is None, \
+ f"Condition must be specified and modulation is not supported for ConditionBlock"
+ elif self.block_type == 'mod':
+ raise NotImplementedError(f"modulation without condition is not implemented")
+ else:
+ assert cond is not None and mod is not None, \
+ f"Condition and modulation must be specified for ConditionModulationBlock"
+
+ def forward_layer(self, layer: nn.Module, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor):
+ if self.block_type == 'basic':
+ return layer(x)
+ elif self.block_type == 'cond':
+ return layer(x, cond)
+ elif self.block_type == 'mod':
+ return layer(x, mod)
+ else:
+ return layer(x, cond, mod)
+
+ def forward(self, x: torch.Tensor, cond: torch.Tensor = None, mod: torch.Tensor = None):
+ # x: [N, L, D]
+ # cond: [N, L_cond, D_cond] or None
+ # mod: [N, D_mod] or None
+ self.assert_runtime_integrity(x, cond, mod)
+
+ if self.block_type in ["cogvideo_cond", "sd3_cond"]:
+ cond = self.linear_cond_proj(cond)
+ for layer in self.layers:
+ if self.training and self.gradient_checkpointing:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ x, cond = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer),
+ x,
+ cond,
+ **ckpt_kwargs,
+ )
+ else:
+ x, cond = layer(
+ hidden_states=x,
+ encoder_hidden_states=cond,
+ temb=None,
+ # image_rotary_emb=None,
+ )
+ x = self.norm(x)
+ else:
+ for layer in self.layers:
+ x = self.forward_layer(layer, x, cond, mod)
+ x = self.norm(x)
+ return x
+
+
+
diff --git a/lam/models/transformer_dit.py b/lam/models/transformer_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..360a697aea1328870a283894b516380e9c5183a6
--- /dev/null
+++ b/lam/models/transformer_dit.py
@@ -0,0 +1,410 @@
+from functools import partial
+import torch
+import torch.nn as nn
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch.nn.functional as F
+assert hasattr(F, "scaled_dot_product_attention")
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0, JointAttnProcessor2_0
+
+
+
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ # num_attention_heads: int,
+ # attention_head_dim: int,
+ # time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ eps: float = 1e-5,
+ # norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+ norm_eps = eps
+ num_attention_heads = num_heads
+ attention_head_dim = dim // num_attention_heads
+ assert attention_head_dim * num_attention_heads == dim
+
+ # 1. Self Attention
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps, bias=True)
+ self.norm1_context = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps, bias=True)
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ # norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ # hidden_states, encoder_hidden_states, temb
+ # )
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states
+
+ # norm & modulate
+ # norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ # hidden_states, encoder_hidden_states, temb
+ # )
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
+ # "feed_forward_chunk_size" can be used to save memory
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+ ff_output = torch.cat(
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+ dim=chunk_dim,
+ )
+ return ff_output
+
+
+class QKNormJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
+
+
+class SD3JointTransformerBlock(nn.Module):
+ r"""
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+ processing of `context` conditions.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ eps: float,
+ # num_attention_heads: int,
+ # attention_head_dim: int,
+ context_pre_only: bool = False,
+ qk_norm: Optional[str] = None,
+ use_dual_attention: bool = False,
+ ):
+ super().__init__()
+ num_attention_heads = num_heads
+ attention_head_dim = dim // num_attention_heads
+ assert attention_head_dim * num_attention_heads == dim
+
+ self.use_dual_attention = use_dual_attention
+ self.context_pre_only = context_pre_only
+ # context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
+
+ # if use_dual_attention:
+ # self.norm1 = SD35AdaLayerNormZeroX(dim)
+ # else:
+ # self.norm1 = AdaLayerNormZero(dim)
+
+ self.norm1 = nn.LayerNorm(dim)
+
+ # if context_norm_type == "ada_norm_continous":
+ # self.norm1_context = AdaLayerNormContinuous(
+ # dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+ # )
+ # elif context_norm_type == "ada_norm_zero":
+ # self.norm1_context = AdaLayerNormZero(dim)
+ # else:
+ # raise ValueError(
+ # f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
+ # )
+ # self.norm1_context = AdaLayerNormZero(dim)
+
+ self.norm1_context = nn.LayerNorm(dim)
+
+ processor = JointAttnProcessor2_0()
+
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=context_pre_only,
+ bias=True,
+ processor=processor,
+ qk_norm=qk_norm,
+ eps=eps,
+ )
+
+ if use_dual_attention:
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ qk_norm=qk_norm,
+ eps=eps,
+ )
+ else:
+ self.attn2 = None
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ if not context_pre_only:
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+ else:
+ self.norm2_context = None
+ self.ff_context = None
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor=None
+ ):
+ # if self.use_dual_attention:
+ # norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+ # hidden_states, emb=temb
+ # )
+ # else:
+ # norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ # if self.context_pre_only:
+ # norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+ # else:
+ # norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ # encoder_hidden_states, emb=temb
+ # )
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+
+ # Attention.
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ )
+
+ # Process attention outputs for the `hidden_states`.
+ # attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ if self.use_dual_attention:
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states)
+ # attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+ hidden_states = hidden_states + attn_output2
+
+ norm_hidden_states = self.norm2(hidden_states)
+ # norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+ # ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ if self.context_pre_only:
+ encoder_hidden_states = None
+ else:
+ # context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ # norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ context_ff_output = _chunked_feed_forward(
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+ )
+ else:
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ # encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
+
+ return hidden_states, encoder_hidden_states
\ No newline at end of file
diff --git a/lam/runners/__init__.py b/lam/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5ce7bb79c8c037b392a178920695bf46d89f9f6
--- /dev/null
+++ b/lam/runners/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from lam.utils.registry import Registry
+
+REGISTRY_RUNNERS = Registry()
+
+from .train import *
+from .infer import *
diff --git a/lam/runners/abstract.py b/lam/runners/abstract.py
new file mode 100644
index 0000000000000000000000000000000000000000..76916e805a5cfbf333d2d63e8607811939a5a639
--- /dev/null
+++ b/lam/runners/abstract.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+
+
+class Runner(ABC):
+ """Abstract runner class"""
+
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def run(self):
+ pass
diff --git a/lam/runners/infer/__init__.py b/lam/runners/infer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..126dc668f176d74c7ff8ab6c0981c26bf3a7950b
--- /dev/null
+++ b/lam/runners/infer/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .lam import LAMInferrer
diff --git a/lam/runners/infer/base_inferrer.py b/lam/runners/infer/base_inferrer.py
new file mode 100644
index 0000000000000000000000000000000000000000..10dcde70b67e29a14d0683432e06e76b929098d8
--- /dev/null
+++ b/lam/runners/infer/base_inferrer.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from abc import abstractmethod
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+
+from lam.runners.abstract import Runner
+
+
+logger = get_logger(__name__)
+
+
+class Inferrer(Runner):
+
+ EXP_TYPE: str = None
+
+ def __init__(self):
+ super().__init__()
+
+ torch._dynamo.config.disable = True
+ self.accelerator = Accelerator()
+
+ self.model : torch.nn.Module = None
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ @property
+ def device(self):
+ return self.accelerator.device
+
+ @abstractmethod
+ def _build_model(self, cfg):
+ pass
+
+ @abstractmethod
+ def infer_single(self, *args, **kwargs):
+ pass
+
+ @abstractmethod
+ def infer(self):
+ pass
+
+ def run(self):
+ self.infer()
diff --git a/lam/runners/infer/head_utils.py b/lam/runners/infer/head_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..85683787dc658586765e29102b100461d81e28c0
--- /dev/null
+++ b/lam/runners/infer/head_utils.py
@@ -0,0 +1,633 @@
+from collections import defaultdict
+import glob
+import os
+import json
+import numpy as np
+from PIL import Image
+import cv2
+import torch
+import decord
+import pickle as pkl
+
+
+def scale_intrs(intrs, ratio_x, ratio_y):
+ if len(intrs.shape) >= 3:
+ intrs[:, 0] = intrs[:, 0] * ratio_x
+ intrs[:, 1] = intrs[:, 1] * ratio_y
+ else:
+ intrs[0] = intrs[0] * ratio_x
+ intrs[1] = intrs[1] * ratio_y
+ return intrs
+
+def calc_new_tgt_size(cur_hw, tgt_size, multiply):
+ ratio = tgt_size / min(cur_hw)
+ tgt_size = int(ratio * cur_hw[0]), int(ratio * cur_hw[1])
+ tgt_size = int(tgt_size[0] / multiply) * multiply, int(tgt_size[1] / multiply) * multiply
+ ratio_y, ratio_x = tgt_size[0] / cur_hw[0], tgt_size[1] / cur_hw[1]
+ return tgt_size, ratio_y, ratio_x
+
+def calc_new_tgt_size_by_aspect(cur_hw, aspect_standard, tgt_size, multiply):
+ assert abs(cur_hw[0] / cur_hw[1] - aspect_standard) < 0.03
+ tgt_size = tgt_size * aspect_standard, tgt_size
+ tgt_size = int(tgt_size[0] / multiply) * multiply, int(tgt_size[1] / multiply) * multiply
+ ratio_y, ratio_x = tgt_size[0] / cur_hw[0], tgt_size[1] / cur_hw[1]
+ return tgt_size, ratio_y, ratio_x
+
+
+def img_center_padding(img_np, pad_ratio):
+
+ ori_w, ori_h = img_np.shape[:2]
+
+ w = round((1 + pad_ratio) * ori_w)
+ h = round((1 + pad_ratio) * ori_h)
+
+ if len(img_np.shape) > 2:
+ img_pad_np = np.zeros((w, h, img_np.shape[2]), dtype=np.uint8)
+ else:
+ img_pad_np = np.zeros((w, h), dtype=np.uint8)
+ offset_h, offset_w = (w - img_np.shape[0]) // 2, (h - img_np.shape[1]) // 2
+ img_pad_np[offset_h: offset_h + img_np.shape[0]:, offset_w: offset_w + img_np.shape[1]] = img_np
+
+ return img_pad_np
+
+
+def resize_image_keepaspect_np(img, max_tgt_size):
+ """
+ similar to ImageOps.contain(img_pil, (img_size, img_size)) # keep the same aspect ratio
+ """
+ h, w = img.shape[:2]
+ ratio = max_tgt_size / max(h, w)
+ new_h, new_w = round(h * ratio), round(w * ratio)
+ return cv2.resize(img, dsize=(new_w, new_h), interpolation=cv2.INTER_AREA)
+
+
+def center_crop_according_to_mask(img, mask, aspect_standard, enlarge_ratio):
+ """
+ img: [H, W, 3]
+ mask: [H, W]
+ """
+ if len(mask.shape) > 2:
+ mask = mask[:, :, 0]
+ ys, xs = np.where(mask > 0)
+
+ if len(xs) == 0 or len(ys) == 0:
+ raise Exception("empty mask")
+
+ x_min = np.min(xs)
+ x_max = np.max(xs)
+ y_min = np.min(ys)
+ y_max = np.max(ys)
+
+ center_x, center_y = img.shape[1]//2, img.shape[0]//2
+
+ half_w = max(abs(center_x - x_min), abs(center_x - x_max))
+ half_h = max(abs(center_y - y_min), abs(center_y - y_max))
+ half_w_raw = half_w
+ half_h_raw = half_h
+ aspect = half_h / half_w
+
+ if aspect >= aspect_standard:
+ half_w = round(half_h / aspect_standard)
+ else:
+ half_h = round(half_w * aspect_standard)
+
+ if half_h > center_y:
+ half_w = round(half_h_raw / aspect_standard)
+ half_h = half_h_raw
+ if half_w > center_x:
+ half_h = round(half_w_raw * aspect_standard)
+ half_w = half_w_raw
+
+ if abs(enlarge_ratio[0] - 1) > 0.01 or abs(enlarge_ratio[1] - 1) > 0.01:
+ enlarge_ratio_min, enlarge_ratio_max = enlarge_ratio
+ enlarge_ratio_max_real = min(center_y / half_h, center_x / half_w)
+ enlarge_ratio_max = min(enlarge_ratio_max_real, enlarge_ratio_max)
+ enlarge_ratio_min = min(enlarge_ratio_max_real, enlarge_ratio_min)
+ enlarge_ratio_cur = np.random.rand() * (enlarge_ratio_max - enlarge_ratio_min) + enlarge_ratio_min
+ half_h, half_w = round(enlarge_ratio_cur * half_h), round(enlarge_ratio_cur * half_w)
+
+ assert half_h <= center_y
+ assert half_w <= center_x
+ assert abs(half_h / half_w - aspect_standard) < 0.03
+
+ offset_x = center_x - half_w
+ offset_y = center_y - half_h
+
+ new_img = img[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]
+ new_mask = mask[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]
+
+ return new_img, new_mask, offset_x, offset_y
+
+
+def preprocess_image(rgb_path, mask_path, intr, pad_ratio, bg_color,
+ max_tgt_size, aspect_standard, enlarge_ratio,
+ render_tgt_size, multiply, need_mask=True,
+ get_shape_param=False):
+ rgb = np.array(Image.open(rgb_path))
+ rgb_raw = rgb.copy()
+ if pad_ratio > 0:
+ rgb = img_center_padding(rgb, pad_ratio)
+
+ rgb = rgb / 255.0
+ if need_mask:
+ if rgb.shape[2] < 4:
+ if mask_path is not None:
+ # mask = np.array(Image.open(mask_path))
+ mask = (np.array(Image.open(mask_path)) > 180) * 255
+ else:
+ from rembg import remove
+ mask = remove(rgb_raw[:, :, (2, 1, 0)])[:, :, -1] # np require [bgr]
+ print("rmbg mask: ", mask.min(), mask.max(), mask.shape)
+ if pad_ratio > 0:
+ mask = img_center_padding(mask, pad_ratio)
+ mask = mask / 255.0
+ else:
+ # rgb: [H, W, 4]
+ assert rgb.shape[2] == 4
+ mask = rgb[:, :, 3] # [H, W]
+ else:
+ # just placeholder
+ mask = np.ones_like(rgb[:, :, 0])
+ if len(mask.shape) > 2:
+ mask = mask[:, :, 0]
+
+ # mask = (mask > 0.5).astype(np.float32)
+ mask = mask.astype(np.float32)
+ if (rgb.shape[0] == rgb.shape[1]) and (rgb.shape[0]==512):
+ rgb = cv2.resize(rgb, (mask.shape[1], mask.shape[0]), interpolation=cv2.INTER_AREA)
+ rgb = rgb[:, :, :3] * mask[:, :, None] + bg_color * (1 - mask[:, :, None])
+
+ # # resize to specific size require by preprocessor of flame-estimator.
+ # rgb = resize_image_keepaspect_np(rgb, max_tgt_size)
+ # mask = resize_image_keepaspect_np(mask, max_tgt_size)
+
+ # crop image to enlarge human area.
+ rgb, mask, offset_x, offset_y = center_crop_according_to_mask(rgb, mask, aspect_standard, enlarge_ratio)
+ if intr is not None:
+ intr[0, 2] -= offset_x
+ intr[1, 2] -= offset_y
+
+ # resize to render_tgt_size for training
+ tgt_hw_size, ratio_y, ratio_x = calc_new_tgt_size_by_aspect(cur_hw=rgb.shape[:2],
+ aspect_standard=aspect_standard,
+ tgt_size=render_tgt_size, multiply=multiply)
+ rgb = cv2.resize(rgb, dsize=(tgt_hw_size[1], tgt_hw_size[0]), interpolation=cv2.INTER_AREA)
+ mask = cv2.resize(mask, dsize=(tgt_hw_size[1], tgt_hw_size[0]), interpolation=cv2.INTER_AREA)
+
+ if intr is not None:
+ intr = scale_intrs(intr, ratio_x=ratio_x, ratio_y=ratio_y)
+ assert abs(intr[0, 2] * 2 - rgb.shape[1]) < 2.5, f"{intr[0, 2] * 2}, {rgb.shape[1]}"
+ assert abs(intr[1, 2] * 2 - rgb.shape[0]) < 2.5, f"{intr[1, 2] * 2}, {rgb.shape[0]}"
+ intr[0, 2] = rgb.shape[1] // 2
+ intr[1, 2] = rgb.shape[0] // 2
+
+ rgb = torch.from_numpy(rgb).float().permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
+ mask = torch.from_numpy(mask[:, :, None]).float().permute(2, 0, 1).unsqueeze(0) # [1, 1, H, W]
+
+ # read shape_param
+ shape_param = None
+ if get_shape_param:
+ cor_flame_path = os.path.join(os.path.dirname(os.path.dirname(rgb_path)),'canonical_flame_param.npz')
+ flame_p = np.load(cor_flame_path)
+ shape_param = torch.FloatTensor(flame_p['shape'])
+
+ return rgb, mask, intr, shape_param
+
+
+def extract_imgs_from_video(video_file, save_root, fps):
+ print(f"extract_imgs_from_video:{video_file}")
+ vr = decord.VideoReader(video_file)
+ for i in range(0, len(vr), fps):
+ frame = vr[i].asnumpy()
+ save_path = os.path.join(save_root, f"{i:05d}.jpg")
+ cv2.imwrite(save_path, frame[:, :, (2, 1, 0)])
+
+def predict_motion_seqs_from_images(image_folder:str, save_root, fps=6):
+ id_name = os.path.splitext(os.path.basename(image_folder))[0]
+ if os.path.isfile(image_folder) and (image_folder.endswith("mp4") or image_folder.endswith("move")):
+ save_frame_root = os.path.join(save_root, "extracted_frames", id_name)
+ if not os.path.exists(save_frame_root):
+ os.makedirs(save_frame_root, exist_ok=True)
+ extract_imgs_from_video(video_file=image_folder, save_root=save_frame_root, fps=fps)
+ else:
+ print("skip extract_imgs_from_video......")
+ image_folder = save_frame_root
+
+ image_folder_abspath = os.path.abspath(image_folder)
+ print(f"predict motion seq:{image_folder_abspath}")
+ save_flame_root = image_folder + "_flame_params_mhmr"
+ if not os.path.exists(save_flame_root):
+ cmd = f"cd thirdparty/multi-hmr && python infer_batch.py --data_root {image_folder_abspath} --out_folder {image_folder_abspath} --crop_head --crop_hand --pad_ratio 0.2 --smplify"
+ os.system(cmd)
+ else:
+ print("skip predict flame.........")
+ return save_flame_root, image_folder
+
+
+def render_flame_mesh(data, render_intrs, c2ws, human_model_path="./pretrained_models/human_model_files"):
+ from lam.models.rendering.flame_model.flame import FlameHead, FlameHeadSubdivided
+ from lam.models.rendering.utils.vis_utils import render_mesh
+
+ subdivide = 2
+ flame_sub_model = FlameHeadSubdivided(
+ 300,
+ 100,
+ add_teeth=True,
+ add_shoulder=False,
+ flame_model_path='pretrained_models/human_model_files/flame_assets/flame/flame2023.pkl',
+ flame_lmk_embedding_path="pretrained_models/human_model_files/flame_assets/flame/landmark_embedding_with_eyes.npy",
+ flame_template_mesh_path="pretrained_models/human_model_files/flame_assets/flame/head_template_mesh.obj",
+ flame_parts_path="pretrained_models/human_model_files/flame_assets/flame/FLAME_masks.pkl",
+ subdivide_num=subdivide
+ ).cuda()
+
+ shape = data['betas'].to('cuda')
+ flame_param = {}
+ flame_param['expr'] = data['expr'].to('cuda')
+ flame_param['rotation'] = data['rotation'].to('cuda')
+ flame_param['neck'] = data['neck_pose'].to('cuda')
+ flame_param['jaw'] = data['jaw_pose'].to('cuda')
+ flame_param['eyes'] = data['eyes_pose'].to('cuda')
+ flame_param['translation'] = data['translation'].to('cuda')
+
+ v_cano = flame_sub_model.get_cano_verts(
+ shape.unsqueeze(0)
+ )
+
+ ret = flame_sub_model.animation_forward(
+ v_cano.repeat(flame_param['expr'].shape[0], 1, 1),
+ shape.unsqueeze(0).repeat(flame_param['expr'].shape[0], 1),
+ flame_param['expr'],
+ flame_param['rotation'],
+ flame_param['neck'],
+ flame_param['jaw'],
+ flame_param['eyes'],
+ flame_param['translation'],
+ zero_centered_at_root_node=False,
+ return_landmarks=False,
+ return_verts_cano=True,
+ # static_offset=batch_data['static_offset'].to('cuda'),
+ static_offset=None,
+ )
+
+ flame_face = flame_sub_model.faces.cpu().squeeze().numpy()
+ mesh_render_list = []
+ num_view = flame_param['expr'].shape[0]
+ for v_idx in range(num_view):
+ intr = render_intrs[v_idx]
+ cam_param = {"focal": torch.tensor([intr[0, 0], intr[1, 1]]),
+ "princpt": torch.tensor([intr[0, 2], intr[1, 2]])}
+ render_shape = int(cam_param['princpt'][1]* 2), int(cam_param['princpt'][0] * 2) # require h, w
+
+ vertices = ret["animated"][v_idx].cpu().squeeze()
+
+ c2w = c2ws[v_idx]
+ w2c = torch.inverse(c2w)
+ R = w2c[:3, :3]
+ T = w2c[:3, 3]
+ vertices = vertices @ R + T
+
+ mesh_render, is_bkg = render_mesh(vertices,
+ flame_face, cam_param,
+ np.ones((render_shape[0],render_shape[1], 3), dtype=np.float32)*255,
+ return_bg_mask=True)
+ mesh_render = mesh_render.astype(np.uint8)
+ mesh_render_list.append(mesh_render)
+ mesh_render = np.stack(mesh_render_list)
+ return mesh_render
+
+def render_flame_mesh_gaga19(data, render_intrs, c2ws, human_model_path="./pretrained_models/human_model_files"):
+ subdivide = 2
+ from lam.models.rendering.flame_model.flame import FlameHeadSubdivided
+ flame_sub_model = FlameHeadSubdivided(
+ 300,
+ 100,
+ add_teeth=True,
+ add_shoulder=False,
+ flame_model_path='pretrained_models/human_model_files/flame_assets/flame/flame2023.pkl',
+ flame_lmk_embedding_path="pretrained_models/human_model_files/flame_assets/flame/landmark_embedding_with_eyes.npy",
+ flame_template_mesh_path="pretrained_models/human_model_files/flame_assets/flame/head_template_mesh.obj",
+ flame_parts_path="pretrained_models/human_model_files/flame_assets/flame/FLAME_masks.pkl",
+ subdivide_num=subdivide
+ ).cuda()
+
+ shape = data['betas'].to('cuda')
+ flame_param = {}
+ flame_param['expr'] = data['expr'].to('cuda')
+ flame_param['rotation'] = data['rotation'].to('cuda')
+ flame_param['neck'] = data['neck_pose'].to('cuda')
+ flame_param['jaw'] = data['jaw_pose'].to('cuda')
+ flame_param['eyes'] = data['eyes_pose'].to('cuda')
+ flame_param['translation'] = data['translation'].to('cuda')
+
+ v_cano = flame_sub_model.get_cano_verts(
+ shape.unsqueeze(0)
+ )
+
+ ret = flame_sub_model.animation_forward(
+ v_cano.repeat(flame_param['expr'].shape[0], 1, 1),
+ shape.unsqueeze(0).repeat(flame_param['expr'].shape[0], 1),
+ flame_param['expr'],
+ flame_param['rotation'],
+ flame_param['neck'],
+ flame_param['jaw'],
+ flame_param['eyes'],
+ flame_param['translation'],
+ zero_centered_at_root_node=False,
+ return_landmarks=False,
+ return_verts_cano=True,
+ # static_offset=batch_data['static_offset'].to('cuda'),
+ static_offset=None,
+ )
+
+ flame_face = flame_sub_model.faces.cpu().squeeze().numpy()
+ mesh_render_list = []
+ num_view = flame_param['expr'].shape[0]
+ import trimesh
+ from lam.models.rendering.flame.vis_utils import RenderMesh
+ for v_idx in range(num_view):
+ mesh = trimesh.Trimesh()
+ mesh.vertices = np.array(ret["animated"][v_idx].cpu().squeeze())
+ mesh.faces = np.array(flame_sub_model.faces.cpu().squeeze())
+
+ renderer = RenderMesh(512, faces=mesh.faces, device="cuda")
+ render_img, _ = renderer(ret["animated"][[v_idx]], focal_length=12.0, transform_matrix=c2ws[[v_idx]])
+ render_img = render_img[0].permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
+ mesh_render_list.append(render_img)
+ mesh_render = np.stack(mesh_render_list)
+ return mesh_render
+
+
+def _load_pose(frame_info):
+ c2w = torch.eye(4)
+ c2w = np.array(frame_info["transform_matrix"])
+ c2w[:3, 1:3] *= -1
+ c2w = torch.FloatTensor(c2w)
+
+ intrinsic = torch.eye(4)
+ intrinsic[0, 0] = frame_info["fl_x"]
+ intrinsic[1, 1] = frame_info["fl_y"]
+ intrinsic[0, 2] = frame_info["cx"]
+ intrinsic[1, 2] = frame_info["cy"]
+ intrinsic = intrinsic.float()
+
+ return c2w, intrinsic
+
+def load_flame_params(flame_file_path, teeth_bs=None):
+
+ flame_param = dict(np.load(flame_file_path, allow_pickle=True))
+
+ flame_param_tensor = {}
+ flame_param_tensor['expr'] = torch.FloatTensor(flame_param['expr'])[0]
+ flame_param_tensor['rotation'] = torch.FloatTensor(flame_param['rotation'])[0]
+ flame_param_tensor['neck_pose'] = torch.FloatTensor(flame_param['neck_pose'])[0]
+ flame_param_tensor['jaw_pose'] = torch.FloatTensor(flame_param['jaw_pose'])[0]
+ flame_param_tensor['eyes_pose'] = torch.FloatTensor(flame_param['eyes_pose'])[0]
+ flame_param_tensor['translation'] = torch.FloatTensor(flame_param['translation'])[0]
+ if teeth_bs is not None:
+ flame_param_tensor['teeth_bs'] = torch.FloatTensor(teeth_bs)
+
+ return flame_param_tensor
+
+def prepare_motion_seqs(motion_seqs_dir, image_folder, save_root, fps,
+ bg_color, aspect_standard, enlarge_ratio,
+ render_image_res, need_mask, multiply=16,
+ vis_motion=False, shape_param=None, test_sample=False, cross_id=False, src_driven=["", ""],
+ max_squen_length=None):
+ if motion_seqs_dir is None:
+ assert image_folder is not None
+ motion_seqs_dir, image_folder = predict_motion_seqs_from_images(image_folder, save_root, fps)
+
+ # source images
+ c2ws, intrs, bg_colors = [], [], []
+ flame_params = []
+
+ # read shape_param
+ if shape_param is None:
+ print("using driven shape params")
+ cor_flame_path = os.path.join(os.path.dirname(motion_seqs_dir),'canonical_flame_param.npz')
+ flame_p = np.load(cor_flame_path)
+ shape_param = torch.FloatTensor(flame_p['shape'])
+
+ transforms_json = os.path.join(os.path.dirname(motion_seqs_dir), f"transforms.json")
+ with open(transforms_json) as fp:
+ data = json.load(fp)
+ all_frames = data["frames"]
+ all_frames = sorted(all_frames, key=lambda x: x["flame_param_path"])
+
+ print(f"len motion_seq:{len(all_frames)}, max motion_seq_len:{max_squen_length}")
+ if(max_squen_length is not None):
+ all_frames = all_frames[:max_squen_length]
+
+ frame_ids = np.array(list(range(len(all_frames))))
+ if test_sample:
+ print("sub sample 50 frames for testing.")
+ sample_num = 50
+ frame_ids = frame_ids[np.linspace(0, frame_ids.shape[0]-1, sample_num).astype(np.int32)]
+ print("sub sample ids:", frame_ids)
+
+ teeth_bs_pth = os.path.join(os.path.dirname(motion_seqs_dir), "tracked_teeth_bs.npz")
+ if os.path.exists(teeth_bs_pth):
+ teeth_bs_lst = np.load(teeth_bs_pth)['expr_teeth']
+ else:
+ teeth_bs_lst = None
+
+ extra_dir_nm = "" if not cross_id else "_crossid"
+ for idx, frame_id in enumerate(frame_ids):
+ frame_info = all_frames[frame_id]
+ flame_path = os.path.join(os.path.dirname(motion_seqs_dir), frame_info["flame_param_path"])
+
+ if image_folder is not None:
+ file_name = os.path.splitext(os.path.basename(flame_path))[0]
+ frame_path = os.path.join(image_folder, file_name + ".png")
+ if not os.path.exists(frame_path):
+ frame_path = os.path.join(image_folder, file_name + ".jpg")
+
+ teeth_bs = teeth_bs_lst[frame_id] if teeth_bs_lst is not None else None
+ flame_param = load_flame_params(flame_path, teeth_bs)
+
+ c2w, intrinsic = _load_pose(frame_info)
+ intrinsic = scale_intrs(intrinsic, 0.5, 0.5)
+
+ c2ws.append(c2w)
+ bg_colors.append(bg_color)
+ intrs.append(intrinsic)
+ flame_params.append(flame_param)
+
+ c2ws = torch.stack(c2ws, dim=0) # [N, 4, 4]
+ intrs = torch.stack(intrs, dim=0) # [N, 4, 4]
+ bg_colors = torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1).repeat(1, 3) # [N, 3]
+
+ flame_params_tmp = defaultdict(list)
+ for flame in flame_params:
+ for k, v in flame.items():
+ flame_params_tmp[k].append(v)
+ for k, v in flame_params_tmp.items():
+ flame_params_tmp[k] = torch.stack(v)
+ flame_params = flame_params_tmp
+ # TODO check different betas for same person
+ flame_params["betas"] = shape_param
+
+ if vis_motion:
+ motion_render = render_flame_mesh(flame_params, intrs, c2ws)
+ else:
+ motion_render = None
+
+ # add batch dim
+ for k, v in flame_params.items():
+ flame_params[k] = v.unsqueeze(0)
+ # print(k, flame_params[k].shape, "motion_seq")
+ c2ws = c2ws.unsqueeze(0)
+ intrs = intrs.unsqueeze(0)
+ bg_colors = bg_colors.unsqueeze(0)
+
+ motion_seqs = {}
+ motion_seqs["render_c2ws"] = c2ws
+ motion_seqs["render_intrs"] = intrs
+ motion_seqs["render_bg_colors"] = bg_colors
+ motion_seqs["flame_params"] = flame_params
+ # motion_seqs["rgbs"] = rgbs
+ motion_seqs["vis_motion_render"] = motion_render
+ return motion_seqs
+
+def prepare_gaga_motion_seqs(motion_seqs_dir, image_folder, save_root, fps,
+ bg_color, aspect_standard, enlarge_ratio,
+ render_image_res, need_mask, multiply=16,
+ vis_motion=False, shape_param=None, test_sample=False,
+ gaga_track_type="vfhq_test50_gagtrack_cano_flamescale1"
+ ):
+ if motion_seqs_dir is None:
+ assert image_folder is not None
+ motion_seqs_dir, image_folder = predict_motion_seqs_from_images(image_folder, save_root, fps)
+
+ # motion_seqs = sorted(glob.glob(os.path.join(motion_seqs_dir, "*.npz")))
+
+ # source images
+ c2ws, intrs, bg_colors = [], [], []
+ flame_params = []
+
+ # read shape_param
+ if shape_param is None:
+ print("using driven shape params")
+ cor_flame_path = os.path.join(os.path.dirname(motion_seqs_dir),'canonical_flame_param.npz')
+ flame_p = np.load(cor_flame_path)
+ shape_param = torch.FloatTensor(flame_p['shape'])
+
+ transforms_json = os.path.join(os.path.dirname(motion_seqs_dir), f"transforms.json")
+ with open(transforms_json) as fp:
+ data = json.load(fp)
+
+ uid = os.path.dirname(motion_seqs_dir).strip('/').split('/')[-1]
+ gag_optim_pth = os.path.join(f"train_data/{gaga_track_type}/", uid, "smoothed.pkl")
+ gag_flame_dict = pkl.load(open(gag_optim_pth, 'rb'))
+
+ all_frames = data["frames"]
+ all_frames = sorted(all_frames, key=lambda x: x["flame_param_path"])
+ print(f"len motion_seq:{len(all_frames)}")
+ frame_ids = np.array(list(range(len(all_frames))))
+ if test_sample:
+ print("sub sample 50 frames for testing.")
+ sample_num = 50
+ frame_ids = frame_ids[np.linspace(0, frame_ids.shape[0]-1, sample_num).astype(np.int32)]
+ print("sub sample ids:", frame_ids)
+
+ def map_flame_params(flame_param):
+ """
+ flame_param
+ ├── bbox: (4,)float32
+ ├── shapecode: (300,)float32
+ ├── expcode: (100,)float32
+ ├── posecode: (6,)float32
+ ├── neckcode: (3,)float32
+ ├── eyecode: (6,)float32
+ └── transform_matrix: (3, 4)float32
+ """
+ flame_param_tensor = {}
+ flame_param_tensor['expr'] = torch.FloatTensor(flame_param['expcode'])
+ # flame_param_tensor['rotation'] = torch.FloatTensor(flame_param['transform_matrix'])[:3, :3]
+ flame_param_tensor['rotation'] = torch.FloatTensor(flame_param['posecode'])[:3]
+ flame_param_tensor['neck_pose'] = torch.FloatTensor(flame_param.get('neckcode', np.zeros(3)))
+ flame_param_tensor['jaw_pose'] = torch.FloatTensor(flame_param['posecode'][3:])
+ flame_param_tensor['eyes_pose'] = torch.FloatTensor(flame_param['eyecode'])
+ flame_param_tensor['translation'] = torch.FloatTensor(np.zeros(3))
+ flame_param_tensor['shape'] = torch.FloatTensor(flame_param['shapecode'])
+ return flame_param_tensor
+
+ def load_pose_from_transform_mat(transform_mat):
+ c2w = torch.FloatTensor(transform_mat).clone() # w2c infact
+
+ # intrinsic is not used.
+ intrinsic = torch.eye(4)
+ intrinsic[0, 0] = 12
+ intrinsic[1, 1] = 12
+ intrinsic[0, 2] = 512 // 2
+ intrinsic[1, 2] = 512 // 2
+ intrinsic = intrinsic.float()
+
+ return c2w, intrinsic
+
+ for idx, frame_id in enumerate(frame_ids):
+ frame_info = all_frames[frame_id]
+ flame_path = os.path.join(os.path.dirname(motion_seqs_dir), frame_info["flame_param_path"])
+
+ # copy sampled images
+ frame_id = int(flame_path.split('/')[-1].split('.')[0])
+ flame_key = "%08d.png" % frame_id
+ # assert idx == frame_id, f"frame id {frame_id} should be the same as idx {idx}"
+ img_path = flame_path.replace("/flame_param/", "/images/").replace(flame_path.split("/")[-1], "%05d_00.png" % frame_id)
+ # img_path = flame_path.replace("/vfhq_test/", "/vfhq_test_tracking/").replace("/flame_param/", "/images/").replace(flame_path.split("/")[-1], flame_key)
+ gt_img = cv2.imread(img_path)
+ if gt_img.shape[0] != 512:
+ gt_img = cv2.resize(gt_img, (512, 512), interpolation=cv2.INTER_AREA)
+ new_img_fd = os.path.join(os.path.dirname(motion_seqs_dir), f"images_sampled50{gaga_track_type}")
+ if not os.path.exists(new_img_fd):
+ os.system(f"mkdir -p {new_img_fd}")
+ new_img_pth = os.path.join(new_img_fd, "%04d.png" % idx)
+ cv2.imwrite(new_img_pth, gt_img)
+
+ gag_flame_param = gag_flame_dict[flame_key]
+ flame_param = map_flame_params(gag_flame_param)
+ c2w, intrinsic = load_pose_from_transform_mat(gag_flame_param['transform_matrix'])
+
+ if shape_param is None:
+ shape_param = flame_param["shape"]
+
+ c2ws.append(c2w)
+ bg_colors.append(bg_color)
+ intrs.append(intrinsic)
+ flame_params.append(flame_param)
+
+ c2ws = torch.stack(c2ws, dim=0) # [N, 4, 4]
+ intrs = torch.stack(intrs, dim=0) # [N, 4, 4]
+ bg_colors = torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1).repeat(1, 3) # [N, 3]
+
+ flame_params_tmp = defaultdict(list)
+ for flame in flame_params:
+ for k, v in flame.items():
+ flame_params_tmp[k].append(v)
+ for k, v in flame_params_tmp.items():
+ flame_params_tmp[k] = torch.stack(v)
+ flame_params = flame_params_tmp
+ # TODO check different betas for same person
+ flame_params["betas"] = shape_param
+
+ if vis_motion:
+ motion_render = render_flame_mesh_gaga19(flame_params, None, c2ws)
+ else:
+ motion_render = None
+
+ # add batch dim
+ for k, v in flame_params.items():
+ flame_params[k] = v.unsqueeze(0)
+ # print(k, flame_params[k].shape, "motion_seq")
+ c2ws = c2ws.unsqueeze(0)
+ intrs = intrs.unsqueeze(0)
+ bg_colors = bg_colors.unsqueeze(0)
+
+ motion_seqs = {}
+ motion_seqs["render_c2ws"] = c2ws
+ motion_seqs["render_intrs"] = intrs
+ motion_seqs["render_bg_colors"] = bg_colors
+ motion_seqs["flame_params"] = flame_params
+ motion_seqs["vis_motion_render"] = motion_render
+ return motion_seqs
diff --git a/lam/runners/infer/lam.py b/lam/runners/infer/lam.py
new file mode 100644
index 0000000000000000000000000000000000000000..3acd135b46e8a89a14e421c5263517544791d594
--- /dev/null
+++ b/lam/runners/infer/lam.py
@@ -0,0 +1,611 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import traceback
+import time
+import torch
+import os
+import argparse
+import mcubes
+import trimesh
+import numpy as np
+from PIL import Image
+from glob import glob
+from omegaconf import OmegaConf
+from tqdm.auto import tqdm
+from accelerate.logging import get_logger
+
+from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image, prepare_gaga_motion_seqs
+
+
+from .base_inferrer import Inferrer
+from lam.datasets.cam_utils import build_camera_principle, build_camera_standard, surrounding_views_linspace, create_intrinsics
+from lam.utils.logging import configure_logger
+from lam.runners import REGISTRY_RUNNERS
+from lam.utils.video import images_to_video
+from lam.utils.hf_hub import wrap_model_hub
+from lam.models.modeling_lam import ModelLAM
+from safetensors.torch import load_file
+import moviepy.editor as mpy
+
+
+logger = get_logger(__name__)
+
+
+def parse_configs():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str)
+ parser.add_argument('--infer', type=str)
+ args, unknown = parser.parse_known_args()
+
+ cfg = OmegaConf.create()
+ cli_cfg = OmegaConf.from_cli(unknown)
+
+ # parse from ENV
+ if os.environ.get('APP_INFER') is not None:
+ args.infer = os.environ.get('APP_INFER')
+ if os.environ.get('APP_MODEL_NAME') is not None:
+ cli_cfg.model_name = os.environ.get('APP_MODEL_NAME')
+
+ if args.config is not None:
+ cfg = OmegaConf.load(args.config)
+ cfg_train = OmegaConf.load(args.config)
+ cfg.source_size = cfg_train.dataset.source_image_res
+ cfg.render_size = cfg_train.dataset.render_image.high
+ _relative_path = os.path.join(cfg_train.experiment.parent, cfg_train.experiment.child, os.path.basename(cli_cfg.model_name).split('_')[-1])
+
+ cfg.save_tmp_dump = os.path.join("exps", 'save_tmp', _relative_path)
+ cfg.image_dump = os.path.join("exps", 'images', _relative_path)
+ cfg.video_dump = os.path.join("exps", 'videos', _relative_path)
+ cfg.mesh_dump = os.path.join("exps", 'meshes', _relative_path)
+
+ if args.infer is not None:
+ cfg_infer = OmegaConf.load(args.infer)
+ cfg.merge_with(cfg_infer)
+ cfg.setdefault("save_tmp_dump", os.path.join("exps", cli_cfg.model_name, 'save_tmp'))
+ cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, 'images'))
+ cfg.setdefault('video_dump', os.path.join("dumps", cli_cfg.model_name, 'videos'))
+ cfg.setdefault('mesh_dump', os.path.join("dumps", cli_cfg.model_name, 'meshes'))
+
+ cfg.motion_video_read_fps = 6
+ cfg.merge_with(cli_cfg)
+
+ """
+ [required]
+ model_name: str
+ image_input: str
+ export_video: bool
+ export_mesh: bool
+
+ [special]
+ source_size: int
+ render_size: int
+ video_dump: str
+ mesh_dump: str
+
+ [default]
+ render_views: int
+ render_fps: int
+ mesh_size: int
+ mesh_thres: float
+ frame_size: int
+ logger: str
+ """
+
+ cfg.setdefault('logger', 'INFO')
+
+ # assert not (args.config is not None and args.infer is not None), "Only one of config and infer should be provided"
+ assert cfg.model_name is not None, "model_name is required"
+ if not os.environ.get('APP_ENABLED', None):
+ assert cfg.image_input is not None, "image_input is required"
+ assert cfg.export_video or cfg.export_mesh, \
+ "At least one of export_video or export_mesh should be True"
+ cfg.app_enabled = False
+ else:
+ cfg.app_enabled = True
+
+ return cfg
+
+
+def count_parameters_excluding_modules(model, exclude_names=[]):
+ """
+ Counts the number of parameters in a PyTorch model, excluding specified modules by name.
+
+ Parameters:
+ - model (torch.nn.Module): The PyTorch model instance.
+ - exclude_names (list of str): List of module names to exclude from the parameter count.
+
+ Returns:
+ - int: Total number of parameters in the model, excluding specified modules.
+ """
+ total_size_bytes = 0
+ total_size_bits = 0
+ for name, module in model.named_modules():
+ # Check if the module name should be excluded
+ # print(name)
+ if any(exclude_name in name for exclude_name in exclude_names):
+ continue
+
+ # Add up the sizes of the parameters if the module is not excluded
+ for param in module.parameters():
+ total_size_bytes += param.numel() # * param.element_size()
+ if param.is_floating_point():
+ total_size_bits += param.numel() # * torch.finfo(param.dtype).bits
+ else:
+ total_size_bits += param.numel() # * torch.iinfo(param.dtype).bits
+
+ # Convert bytes to megabytes
+ total_size_mb = total_size_bytes / (1024 ** 2)
+ print("==="*16*3, f"\nTotal number of parameters: {total_size_mb}M", "\n"+"==="*16*3)
+ print(f"model size: {total_size_bits} / bit | {total_size_bits / 1e6:.2f} / MB")
+
+ return total_size_mb
+
+
+@REGISTRY_RUNNERS.register('infer.lam')
+class LAMInferrer(Inferrer):
+
+ EXP_TYPE: str = 'lam'
+
+ def __init__(self):
+ super().__init__()
+
+ self.cfg = parse_configs()
+ """
+ configure_logger(
+ stream_level=self.cfg.logger,
+ log_level=self.cfg.logger,
+ )
+ """
+
+ self.model: LAMInferrer = self._build_model(self.cfg).to(self.device)
+
+ def _build_model(self, cfg):
+ """
+ from lam.models import model_dict
+ hf_model_cls = wrap_model_hub(model_dict[self.EXP_TYPE])
+ model = hf_model_cls.from_pretrained(cfg.model_name)
+ """
+ from lam.models import ModelLAM
+ model = ModelLAM(**cfg.model)
+ # total_params = count_parameters_excluding_modules(model, [])
+ # total_params = count_parameters_excluding_modules(model, ['encoder'])
+
+ resume = os.path.join(cfg.model_name, "model.safetensors")
+ print("==="*16*3)
+ print("loading pretrained weight from:", resume)
+ if resume.endswith('safetensors'):
+ ckpt = load_file(resume, device='cpu')
+ else:
+ ckpt = torch.load(resume, map_location='cpu')
+ state_dict = model.state_dict()
+ for k, v in ckpt.items():
+ if k in state_dict:
+ if state_dict[k].shape == v.shape:
+ state_dict[k].copy_(v)
+ else:
+ print(f"WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.")
+ else:
+ print(f"WARN] unexpected param {k}: {v.shape}")
+ print("finish loading pretrained weight from:", resume)
+ print("==="*16*3)
+ return model
+
+ def _default_source_camera(self, dist_to_center: float = 2.0, batch_size: int = 1, device: torch.device = torch.device('cpu')):
+ # return: (N, D_cam_raw)
+ canonical_camera_extrinsics = torch.tensor([[
+ [1, 0, 0, 0],
+ [0, 0, -1, -dist_to_center],
+ [0, 1, 0, 0],
+ ]], dtype=torch.float32, device=device)
+ canonical_camera_intrinsics = create_intrinsics(
+ f=0.75,
+ c=0.5,
+ device=device,
+ ).unsqueeze(0)
+ source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics)
+ return source_camera.repeat(batch_size, 1)
+
+ def _default_render_cameras(self, n_views: int, batch_size: int = 1, device: torch.device = torch.device('cpu')):
+ # return: (N, M, D_cam_render)
+ render_camera_extrinsics = surrounding_views_linspace(n_views=n_views, device=device)
+ render_camera_intrinsics = create_intrinsics(
+ f=0.75,
+ c=0.5,
+ device=device,
+ ).unsqueeze(0).repeat(render_camera_extrinsics.shape[0], 1, 1)
+ render_cameras = build_camera_standard(render_camera_extrinsics, render_camera_intrinsics)
+ return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+
+ def infer_planes(self, image: torch.Tensor, source_cam_dist: float):
+ N = image.shape[0]
+ source_camera = self._default_source_camera(dist_to_center=source_cam_dist, batch_size=N, device=self.device)
+ planes = self.model.forward_planes(image, source_camera)
+ assert N == planes.shape[0]
+ return planes
+
+ def infer_video(self, planes: torch.Tensor, frame_size: int, render_size: int, render_views: int, render_fps: int, dump_video_path: str):
+ N = planes.shape[0]
+ render_cameras = self._default_render_cameras(n_views=render_views, batch_size=N, device=self.device)
+ render_anchors = torch.zeros(N, render_cameras.shape[1], 2, device=self.device)
+ render_resolutions = torch.ones(N, render_cameras.shape[1], 1, device=self.device) * render_size
+ render_bg_colors = torch.ones(N, render_cameras.shape[1], 1, device=self.device, dtype=torch.float32) * 0. # 1.
+
+ frames = []
+ for i in range(0, render_cameras.shape[1], frame_size):
+ frames.append(
+ self.model.synthesizer(
+ planes=planes,
+ cameras=render_cameras[:, i:i+frame_size],
+ anchors=render_anchors[:, i:i+frame_size],
+ resolutions=render_resolutions[:, i:i+frame_size],
+ bg_colors=render_bg_colors[:, i:i+frame_size],
+ region_size=render_size,
+ )
+ )
+ # merge frames
+ frames = {
+ k: torch.cat([r[k] for r in frames], dim=1)
+ for k in frames[0].keys()
+ }
+ # dump
+ os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
+ for k, v in frames.items():
+ if k == 'images_rgb':
+ images_to_video(
+ images=v[0],
+ output_path=dump_video_path,
+ fps=render_fps,
+ gradio_codec=self.cfg.app_enabled,
+ )
+
+ def infer_mesh(self, planes: torch.Tensor, mesh_size: int, mesh_thres: float, dump_mesh_path: str):
+ grid_out = self.model.synthesizer.forward_grid(
+ planes=planes,
+ grid_size=mesh_size,
+ )
+
+ vtx, faces = mcubes.marching_cubes(grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres)
+ vtx = vtx / (mesh_size - 1) * 2 - 1
+
+ vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
+ vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() # (0, 1)
+ vtx_colors = (vtx_colors * 255).astype(np.uint8)
+
+ mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
+
+ # dump
+ os.makedirs(os.path.dirname(dump_mesh_path), exist_ok=True)
+ mesh.export(dump_mesh_path)
+
+ def save_imgs_2_video(self, imgs, v_pth, fps):
+ img_lst = [imgs[i] for i in range(imgs.shape[0])]
+ # Convert the list of NumPy arrays to a list of ImageClip objects
+ clips = [mpy.ImageClip(img).set_duration(0.1) for img in img_lst] # 0.1 seconds per frame
+
+ # Concatenate the ImageClips into a single VideoClip
+ video = mpy.concatenate_videoclips(clips, method="compose")
+
+ # Write the VideoClip to a file
+ video.write_videofile(v_pth, fps=fps) # setting fps to 10 as example
+
+ def infer_single(self, image_path: str,
+ motion_seqs_dir,
+ motion_img_dir,
+ motion_video_read_fps,
+ export_video: bool,
+ export_mesh: bool,
+ dump_tmp_dir:str, # require by extracting motion seq from video, to save some results
+ dump_image_dir:str,
+ dump_video_path: str,
+ dump_mesh_path: str,
+ gaga_track_type: str):
+ source_size = self.cfg.source_size
+ render_size = self.cfg.render_size
+ # render_views = self.cfg.render_views
+ render_fps = self.cfg.render_fps
+ # mesh_size = self.cfg.mesh_size
+ # mesh_thres = self.cfg.mesh_thres
+ # frame_size = self.cfg.frame_size
+ # source_cam_dist = self.cfg.source_cam_dist if source_cam_dist is None else source_cam_dist
+ aspect_standard = 1.0/1.0
+ motion_img_need_mask = self.cfg.get("motion_img_need_mask", False) # False
+ vis_motion = self.cfg.get("vis_motion", False) # False
+ save_ply = self.cfg.get("save_ply", False) # False
+ save_img = self.cfg.get("save_img", False) # False
+ # mask_path = image_path.replace("/images/", "/mask/").replace(".png", ".jpg")
+ rendered_bg = 1.
+ ref_bg = 1.
+ mask_path = image_path.replace("/images/", "/fg_masks/").replace(".jpg", ".png")
+ if ref_bg < 1.:
+ if "VFHQ_TEST" in image_path:
+ mask_path = image_path.replace("/VFHQ_TEST/", "/mask/").replace("/images/", "/mask/").replace(".png", ".jpg")
+ else:
+ mask_path = image_path.replace("/vfhq_test_nooffset_export/", "/mask/").replace("/images/", "/mask/").replace(".png", ".jpg")
+ if not os.path.exists(mask_path):
+ print("Warning: Mask path not exists:", mask_path)
+ mask_path = None
+ else:
+ print("load mask from:", mask_path)
+
+ # prepare reference image
+ if "hdtf" in image_path:
+ uid = image_path.split('/')[-3]
+ split0 = uid.replace(uid.split('_')[-1], '0')
+ print("==="*16*3, "\n"+image_path, uid, split0)
+ image_path = image_path.replace(uid, split0)
+ mask_path = mask_path.replace(uid, split0)
+ print(image_path, "\n"+"==="*16*3)
+ print(mask_path, "\n"+"==="*16*3)
+ if hasattr(self.cfg.model, "use_albedo_input") and (self.cfg.model.get("use_albedo_input", False)):
+ image_path = image_path.replace("/images/", "/images_hydelight/")
+ image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0, bg_color=ref_bg,
+ max_tgt_size=None, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
+ render_tgt_size=source_size, multiply=14, need_mask=True, get_shape_param=True)
+ # save masked image for vis
+ save_ref_img_path = os.path.join(dump_tmp_dir, "refer_" + os.path.basename(image_path))
+ vis_ref_img = (image[0].permute(1, 2 ,0).cpu().detach().numpy() * 255).astype(np.uint8)
+ Image.fromarray(vis_ref_img).save(save_ref_img_path)
+ # prepare motion seq
+ test_sample=self.cfg.get("test_sample", True)
+ # test_sample=True
+ if gaga_track_type == "":
+ print("==="*16*3, "\nuse vhap tracked results!", "\n"+"==="*16*3)
+ src = image_path.split('/')[-3]
+ driven = motion_seqs_dir.split('/')[-2]
+ src_driven = [src, driven]
+ motion_seq = prepare_motion_seqs(motion_seqs_dir, motion_img_dir, save_root=dump_tmp_dir, fps=motion_video_read_fps,
+ bg_color=rendered_bg, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1,0],
+ render_image_res=render_size, multiply=16,
+ need_mask=motion_img_need_mask, vis_motion=vis_motion,
+ shape_param=shape_param, test_sample=test_sample, cross_id=self.cfg.get("cross_id", False), src_driven=src_driven)
+ else:
+ print("==="*16*3, "\nuse gaga tracked results:", gaga_track_type, "\n"+"==="*16*3)
+ motion_seq = prepare_gaga_motion_seqs(motion_seqs_dir, motion_img_dir, save_root=dump_tmp_dir, fps=motion_video_read_fps,
+ bg_color=rendered_bg, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1,0],
+ render_image_res=render_size, multiply=16,
+ need_mask=motion_img_need_mask, vis_motion=vis_motion,
+ shape_param=shape_param, test_sample=test_sample, gaga_track_type=gaga_track_type)
+
+ # return
+
+ motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)
+ # print(motion_seq["flame_params"].keys())
+ start_time = time.time()
+ device="cuda"
+ dtype=torch.float32
+ # dtype=torch.bfloat16
+ self.model.to(dtype)
+ print("start to inference...................")
+ with torch.no_grad():
+ # TODO check device and dtype
+ res = self.model.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None,
+ render_c2ws=motion_seq["render_c2ws"].to(device),
+ render_intrs=motion_seq["render_intrs"].to(device),
+ render_bg_colors=motion_seq["render_bg_colors"].to(device),
+ flame_params={k:v.to(device) for k, v in motion_seq["flame_params"].items()})
+
+ print(f"time elapsed: {time.time() - start_time}")
+ rgb = res["comp_rgb"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
+ rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8)
+ only_pred = rgb
+ if vis_motion:
+ # print(rgb.shape, motion_seq["vis_motion_render"].shape)
+ import cv2
+ vis_ref_img = np.tile(cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :, :], (rgb.shape[0], 1, 1, 1))
+ blend_ratio = 0.7
+ blend_res = ((1 - blend_ratio) * rgb + blend_ratio * motion_seq["vis_motion_render"]).astype(np.uint8)
+ # rgb = np.concatenate([rgb, motion_seq["vis_motion_render"], blend_res, vis_ref_img], axis=2)
+ rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2)
+
+ os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
+ # images_to_video(rgb, output_path=dump_video_path, fps=render_fps, gradio_codec=False, verbose=True)
+ self.save_imgs_2_video(rgb, dump_video_path, render_fps)
+ if save_img and dump_image_dir is not None:
+ for i in range(rgb.shape[0]):
+ save_file = os.path.join(dump_image_dir, f"{i:04d}.png")
+ Image.fromarray(only_pred[i]).save(save_file)
+ if save_ply and dump_mesh_path is not None:
+ res["3dgs"][i][0][0].save_ply(os.path.join(dump_image_dir, f"{i:04d}.ply"))
+
+ dump_cano_dir = "./exps/cano_gs/"
+ if not os.path.exists(dump_cano_dir):
+ os.system(f"mkdir -p {dump_cano_dir}")
+ cano_ply_pth = os.path.join(dump_cano_dir, os.path.basename(dump_image_dir) + ".ply")
+ # res['cano_gs_lst'][0].save_ply(cano_ply_pth, rgb2sh=True, offset2xyz=False)
+ # res['cano_gs_lst'][0].save_ply(cano_ply_pth, rgb2sh=True, offset2xyz=False, albedo2rgb=True)
+ cano_ply_pth = os.path.join(dump_cano_dir, os.path.basename(dump_image_dir) + "_gs_offset.ply")
+ res['cano_gs_lst'][0].save_ply(cano_ply_pth, rgb2sh=False, offset2xyz=True, albedo2rgb=False)
+ # res['cano_gs_lst'][0].save_ply(cano_ply_pth, rgb2sh=False, offset2xyz=True)
+
+ def save_color_points(points, colors, sv_pth, sv_fd="debug_vis/dataloader/"):
+ points = points.squeeze().detach().cpu().numpy()
+ colors = colors.squeeze().detach().cpu().numpy()
+ sv_pth = os.path.join(sv_fd, sv_pth)
+ if not os.path.exists(sv_fd):
+ os.system(f"mkdir -p {sv_fd}")
+ with open(sv_pth, 'w') as of:
+ for point, color in zip(points, colors):
+ print('v', point[0], point[1], point[2], color[0], color[1], color[2], file=of)
+
+ # save canonical color point clouds
+ save_color_points(res['cano_gs_lst'][0].xyz, res["cano_gs_lst"][0].shs[:, 0, :], "framework_img.obj", sv_fd=dump_cano_dir)
+
+ # Export the template mesh to an OBJ file
+ import trimesh
+ vtxs = res['cano_gs_lst'][0].xyz - res['cano_gs_lst'][0].offset
+ vtxs = vtxs.detach().cpu().numpy()
+ faces = self.model.renderer.flame_model.faces.detach().cpu().numpy()
+ mesh = trimesh.Trimesh(vertices=vtxs, faces=faces)
+ mesh.export(os.path.join(dump_cano_dir, os.path.basename(dump_image_dir) + '_shaped_mesh.obj'))
+
+ # Export textured deformed mesh
+ import lam.models.rendering.utils.mesh_utils as mesh_utils
+ vtxs = res['cano_gs_lst'][0].xyz.detach().cpu()
+ faces = self.model.renderer.flame_model.faces.detach().cpu()
+ colors = res['cano_gs_lst'][0].shs.squeeze(1).detach().cpu()
+ pth = os.path.join(dump_cano_dir, os.path.basename(dump_image_dir) + '_textured_mesh.obj')
+ print("Save textured mesh to:", pth)
+ mesh_utils.save_obj(pth, vtxs, faces, textures=colors, texture_type="vertex")
+
+ # if dum_mesh_path is not None:
+ # for idx, gs in enumerate(res["3dgs"]):
+ # gs.save_ply(f"{:04d}.ply")
+
+ def infer(self):
+ image_paths = []
+ # hard code
+ if os.path.isfile(self.cfg.image_input):
+ omit_prefix = os.path.dirname(self.cfg.image_input)
+ image_paths = [self.cfg.image_input]
+ else:
+ # ids = sorted(os.listdir(self.cfg.image_input))
+ # image_paths = [os.path.join(self.cfg.image_input, e, "images/00000_00.png") for e in ids]
+ image_paths = glob(os.path.join(self.cfg.image_input, "*.jpg"))
+ omit_prefix = self.cfg.image_input
+
+ """
+ # image_paths = glob("train_data/demo_export/DEMOVIDEO/*/images/00000_00.png")
+ image_paths = glob("train_data/vfhq_test/VFHQ_TEST/Clip+G0DGRma_p48+P0+C0+F11208-11383/images/00000_00.png")
+ image_paths = glob("train_data/SIDE_FACE/*/images/00000_00.png")
+ image_paths = glob("train_data/vfhq_test/VFHQ_TEST/*/images/00000_00.png")
+
+ import json
+ # uids = json.load(open("./train_data/vfhq_vhap/selected_id.json", 'r'))["self_id"]
+ # image_paths = [os.path.join("train_data/vfhq_test/VFHQ_TEST/", uid, "images/00000_00.png") for uid in uids]
+ image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png")
+ # image_paths = glob("train_data/nersemble_vhap/export/017_SEN-01-cramp_small_danger_v16_DS4_whiteBg_staticOffset_maskBelowLine/images/00000_00.png")
+ # image_paths = glob("train_data/nersemble_vhap/export/374_SEN-01-cramp_small_danger_v16_DS4_whiteBg_staticOffset_maskBelowLine/images/00000_00.png")
+ image_paths = glob("train_data/nersemble_vhap/export/375_SEN-01-cramp_small_danger_v16_DS4_whiteBg_staticOffset_maskBelowLine/images/00000_00.png")
+
+ image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png")
+ """
+
+ # image_paths = glob("train_data/hdtf_test/export/*/images/00000_00.png")
+
+ image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png") # [0:1]
+
+ # image_paths = glob("train_data/vfhq_test/VFHQ_TEST/*/images/00000_00.png")
+ print(len(image_paths), image_paths)
+
+ # image_paths = ["train_data/vfhq_test/VFHQ_TEST/Clip+VjvX4tzzlbo+P2+C0+F5669-5935/images/00000_00.png"]
+ # image_paths = ["train_data/vfhq_test/VFHQ_TEST/Clip+KSF3tPr9zAk+P0+C2+F8769-8880/images/00000_00.png"]
+ image_paths = ["train_data/vfhq_test/VFHQ_TEST/Clip+G0DGRma_p48+P0+C0+F11208-11383/images/00000_00.png"]
+
+ image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png")
+
+ uids = ['Clip+1qf8dZpLED0+P2+C1+F5731-5855', 'Clip+8vcxTHoDadk+P3+C0+F27918-28036', 'Clip+gsHu2fb3aj0+P0+C0+F17563-17742']
+ image_paths = ["train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png".replace("*", item) for item in uids]
+
+ image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png")
+
+ image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png")
+
+ image_paths = glob("train_data/test_2w_cases/*/images/00000_00.png")
+
+ # if os.path.isfile(self.cfg.image_input):
+ # omit_prefix = os.path.dirname(self.cfg.image_input)
+ # image_paths.append(self.cfg.image_input)
+ # else:
+ # omit_prefix = self.cfg.image_input
+ # suffixes = ('.jpg', '.jpeg', '.png', '.webp')
+ # for root, dirs, files in os.walk(self.cfg.image_input):
+ # for file in files:
+ # if file.endswith(suffixes):
+ # image_paths.append(os.path.join(root, file))
+ # image_paths.sort()
+
+ # alloc to each DDP worker
+ # image_paths = image_paths[self.accelerator.process_index::self.accelerator.num_processes]
+ if "hdtf" in image_paths[0]:
+ image_paths = image_paths[self.cfg.get("rank", 0)::self.cfg.get("nodes", 1)]
+
+ gaga_track_type = self.cfg.get("gaga_track_type", "")
+ if gaga_track_type is None:
+ gaga_track_type = ""
+ print("==="*16*3, "\nUse gaga_track_type:", gaga_track_type, "\n"+"==="*16*3)
+
+ if self.cfg.get("cross_id", False):
+ import json
+ cross_id_lst = json.load(open("train_data/Cross-identity-info.json", 'r'))
+ src2driven = {item["src"]: item["driven"] for item in cross_id_lst}
+
+ for image_path in tqdm(image_paths, disable=not self.accelerator.is_local_main_process):
+ try:
+ # self.cfg.motion_seqs_dir = image_path.replace("/images/00000_00.png", "/flame_param")
+ motion_seqs_dir = self.cfg.motion_seqs_dir
+ if "VFHQ_TEST" in image_path or "vfhq_test_nooffset_export" in image_path or "hdtf" in image_path:
+ motion_seqs_dir = os.path.join(*image_path.split('/')[:-2], "flame_param")
+ # read shape_param
+ if self.cfg.get("cross_id", False):
+ src = motion_seqs_dir.split('/')[-2]
+ driven = src2driven[src]
+ motion_seqs_dir = motion_seqs_dir.replace(src, driven)
+
+ print("motion_seqs_dir:", motion_seqs_dir)
+ # prepare dump paths
+ image_name = os.path.basename(image_path)
+ uid = image_name.split('.')[0]
+ subdir_path = os.path.dirname(image_path).replace(omit_prefix, '')
+ subdir_path = subdir_path[1:] if subdir_path.startswith('/') else subdir_path
+ # hard code
+ subdir_path = gaga_track_type
+ if self.cfg.get("cross_id", False):
+ subdir_path = "cross_id"
+ print("==="*16*3, "\n"+ "subdir_path:", subdir_path, "\n"+"==="*16*3)
+ uid = os.path.basename(os.path.dirname(os.path.dirname(image_path)))
+ print("subdir_path and uid:", subdir_path, uid)
+ dump_video_path = os.path.join(
+ self.cfg.video_dump,
+ subdir_path,
+ f'{uid}.mp4',
+ )
+ dump_image_dir = os.path.join(
+ self.cfg.image_dump,
+ subdir_path,
+ f'{uid}'
+ )
+ dump_tmp_dir = os.path.join(
+ self.cfg.image_dump,
+ subdir_path,
+ "tmp_res"
+ )
+ dump_mesh_path = os.path.join(
+ self.cfg.mesh_dump,
+ subdir_path,
+ # f'{uid}.ply',
+ )
+ os.makedirs(dump_image_dir, exist_ok=True)
+ os.makedirs(dump_tmp_dir, exist_ok=True)
+ os.makedirs(dump_mesh_path, exist_ok=True)
+
+ # if os.path.exists(dump_video_path):
+ # print(f"skip:{image_path}")
+ # continue
+
+ self.infer_single(
+ image_path,
+ motion_seqs_dir=motion_seqs_dir,
+ motion_img_dir=self.cfg.motion_img_dir,
+ motion_video_read_fps=self.cfg.motion_video_read_fps,
+ export_video=self.cfg.export_video,
+ export_mesh=self.cfg.export_mesh,
+ dump_tmp_dir=dump_tmp_dir,
+ dump_image_dir=dump_image_dir,
+ dump_video_path=dump_video_path,
+ dump_mesh_path=dump_mesh_path,
+ gaga_track_type=gaga_track_type
+ )
+ except:
+ traceback.print_exc()
diff --git a/lam/runners/infer/utils.py b/lam/runners/infer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..32643aa173244f6e93d3f546dc297bcebdd016e3
--- /dev/null
+++ b/lam/runners/infer/utils.py
@@ -0,0 +1,317 @@
+from collections import defaultdict
+import glob
+import os
+import json
+import numpy as np
+from PIL import Image
+import cv2
+import torch
+import decord
+
+
+def scale_intrs(intrs, ratio_x, ratio_y):
+ if len(intrs.shape) >= 3:
+ intrs[:, 0] = intrs[:, 0] * ratio_x
+ intrs[:, 1] = intrs[:, 1] * ratio_y
+ else:
+ intrs[0] = intrs[0] * ratio_x
+ intrs[1] = intrs[1] * ratio_y
+ return intrs
+
+def calc_new_tgt_size(cur_hw, tgt_size, multiply):
+ ratio = tgt_size / min(cur_hw)
+ tgt_size = int(ratio * cur_hw[0]), int(ratio * cur_hw[1])
+ tgt_size = int(tgt_size[0] / multiply) * multiply, int(tgt_size[1] / multiply) * multiply
+ ratio_y, ratio_x = tgt_size[0] / cur_hw[0], tgt_size[1] / cur_hw[1]
+ return tgt_size, ratio_y, ratio_x
+
+def calc_new_tgt_size_by_aspect(cur_hw, aspect_standard, tgt_size, multiply):
+ assert abs(cur_hw[0] / cur_hw[1] - aspect_standard) < 0.03
+ tgt_size = tgt_size * aspect_standard, tgt_size
+ tgt_size = int(tgt_size[0] / multiply) * multiply, int(tgt_size[1] / multiply) * multiply
+ ratio_y, ratio_x = tgt_size[0] / cur_hw[0], tgt_size[1] / cur_hw[1]
+ return tgt_size, ratio_y, ratio_x
+
+def _load_pose(pose):
+ intrinsic = torch.eye(4)
+ intrinsic[0, 0] = pose["focal"][0]
+ intrinsic[1, 1] = pose["focal"][1]
+ intrinsic[0, 2] = pose["princpt"][0]
+ intrinsic[1, 2] = pose["princpt"][1]
+ intrinsic = intrinsic.float()
+
+ c2w = torch.eye(4)
+ # c2w[:3, :3] = torch.tensor(pose["R"])
+ # c2w[3, :3] = torch.tensor(pose["t"])
+ c2w = c2w.float()
+
+ return c2w, intrinsic
+
+
+def img_center_padding(img_np, pad_ratio):
+
+ ori_w, ori_h = img_np.shape[:2]
+
+ w = round((1 + pad_ratio) * ori_w)
+ h = round((1 + pad_ratio) * ori_h)
+
+ if len(img_np.shape) > 2:
+ img_pad_np = np.zeros((w, h, img_np.shape[2]), dtype=np.uint8)
+ else:
+ img_pad_np = np.zeros((w, h), dtype=np.uint8)
+ offset_h, offset_w = (w - img_np.shape[0]) // 2, (h - img_np.shape[1]) // 2
+ img_pad_np[offset_h: offset_h + img_np.shape[0]:, offset_w: offset_w + img_np.shape[1]] = img_np
+
+ return img_pad_np
+
+
+def resize_image_keepaspect_np(img, max_tgt_size):
+ """
+ similar to ImageOps.contain(img_pil, (img_size, img_size)) # keep the same aspect ratio
+ """
+ h, w = img.shape[:2]
+ ratio = max_tgt_size / max(h, w)
+ new_h, new_w = round(h * ratio), round(w * ratio)
+ return cv2.resize(img, dsize=(new_w, new_h), interpolation=cv2.INTER_AREA)
+
+
+def center_crop_according_to_mask(img, mask, aspect_standard, enlarge_ratio):
+ """
+ img: [H, W, 3]
+ mask: [H, W]
+ """
+ ys, xs = np.where(mask > 0)
+
+ if len(xs) == 0 or len(ys) == 0:
+ raise Exception("empty mask")
+
+ x_min = np.min(xs)
+ x_max = np.max(xs)
+ y_min = np.min(ys)
+ y_max = np.max(ys)
+
+ center_x, center_y = img.shape[1]//2, img.shape[0]//2
+
+ half_w = max(abs(center_x - x_min), abs(center_x - x_max))
+ half_h = max(abs(center_y - y_min), abs(center_y - y_max))
+ half_w_raw = half_w
+ half_h_raw = half_h
+ aspect = half_h / half_w
+
+ if aspect >= aspect_standard:
+ half_w = round(half_h / aspect_standard)
+ else:
+ half_h = round(half_w * aspect_standard)
+
+ if half_h > center_y:
+ half_w = round(half_h_raw / aspect_standard)
+ half_h = half_h_raw
+ if half_w > center_x:
+ half_h = round(half_w_raw * aspect_standard)
+ half_w = half_w_raw
+
+ if abs(enlarge_ratio[0] - 1) > 0.01 or abs(enlarge_ratio[1] - 1) > 0.01:
+ enlarge_ratio_min, enlarge_ratio_max = enlarge_ratio
+ enlarge_ratio_max_real = min(center_y / half_h, center_x / half_w)
+ enlarge_ratio_max = min(enlarge_ratio_max_real, enlarge_ratio_max)
+ enlarge_ratio_min = min(enlarge_ratio_max_real, enlarge_ratio_min)
+ enlarge_ratio_cur = np.random.rand() * (enlarge_ratio_max - enlarge_ratio_min) + enlarge_ratio_min
+ half_h, half_w = round(enlarge_ratio_cur * half_h), round(enlarge_ratio_cur * half_w)
+
+ assert half_h <= center_y
+ assert half_w <= center_x
+ assert abs(half_h / half_w - aspect_standard) < 0.03
+
+ offset_x = center_x - half_w
+ offset_y = center_y - half_h
+
+ new_img = img[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]
+ new_mask = mask[offset_y: offset_y + 2*half_h, offset_x: offset_x + 2*half_w]
+
+ return new_img, new_mask, offset_x, offset_y
+
+
+def preprocess_image(rgb_path, mask_path, intr, pad_ratio, bg_color,
+ max_tgt_size, aspect_standard, enlarge_ratio,
+ render_tgt_size, multiply, need_mask=True):
+ rgb = np.array(Image.open(rgb_path))
+ rgb_raw = rgb.copy()
+ if pad_ratio > 0:
+ rgb = img_center_padding(rgb, pad_ratio)
+
+ rgb = rgb / 255.0
+ if need_mask:
+ if rgb.shape[2] < 4:
+ if mask_path is not None:
+ mask = np.array(Image.open(mask_path))
+ else:
+ from rembg import remove
+ mask = remove(rgb_raw[:, :, (2, 1, 0)])[:, :, -1] # np require [bgr]
+ print("rmbg mask: ", mask.min(), mask.max(), mask.shape)
+ if pad_ratio > 0:
+ mask = img_center_padding(mask, pad_ratio)
+ mask = mask / 255.0
+ else:
+ # rgb: [H, W, 4]
+ assert rgb.shape[2] == 4
+ mask = rgb[:, :, 3] # [H, W]
+ else:
+ # just placeholder
+ mask = np.ones_like(rgb[:, :, 0])
+
+ mask = (mask > 0.5).astype(np.float32)
+ rgb = rgb[:, :, :3] * mask[:, :, None] + bg_color * (1 - mask[:, :, None])
+
+ # resize to specific size require by preprocessor of flame-estimator.
+ rgb = resize_image_keepaspect_np(rgb, max_tgt_size)
+ mask = resize_image_keepaspect_np(mask, max_tgt_size)
+
+ # crop image to enlarge human area.
+ rgb, mask, offset_x, offset_y = center_crop_according_to_mask(rgb, mask, aspect_standard, enlarge_ratio)
+ if intr is not None:
+ intr[0, 2] -= offset_x
+ intr[1, 2] -= offset_y
+
+ # resize to render_tgt_size for training
+ tgt_hw_size, ratio_y, ratio_x = calc_new_tgt_size_by_aspect(cur_hw=rgb.shape[:2],
+ aspect_standard=aspect_standard,
+ tgt_size=render_tgt_size, multiply=multiply)
+ rgb = cv2.resize(rgb, dsize=(tgt_hw_size[1], tgt_hw_size[0]), interpolation=cv2.INTER_AREA)
+ mask = cv2.resize(mask, dsize=(tgt_hw_size[1], tgt_hw_size[0]), interpolation=cv2.INTER_AREA)
+
+ if intr is not None:
+ intr = scale_intrs(intr, ratio_x=ratio_x, ratio_y=ratio_y)
+ assert abs(intr[0, 2] * 2 - rgb.shape[1]) < 2.5, f"{intr[0, 2] * 2}, {rgb.shape[1]}"
+ assert abs(intr[1, 2] * 2 - rgb.shape[0]) < 2.5, f"{intr[1, 2] * 2}, {rgb.shape[0]}"
+ intr[0, 2] = rgb.shape[1] // 2
+ intr[1, 2] = rgb.shape[0] // 2
+
+ rgb = torch.from_numpy(rgb).float().permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
+ mask = torch.from_numpy(mask[:, :, None]).float().permute(2, 0, 1).unsqueeze(0) # [1, 1, H, W]
+ return rgb, mask, intr
+
+
+def extract_imgs_from_video(video_file, save_root, fps):
+ print(f"extract_imgs_from_video:{video_file}")
+ vr = decord.VideoReader(video_file)
+ for i in range(0, len(vr), fps):
+ frame = vr[i].asnumpy()
+ save_path = os.path.join(save_root, f"{i:05d}.jpg")
+ cv2.imwrite(save_path, frame[:, :, (2, 1, 0)])
+
+def predict_motion_seqs_from_images(image_folder:str, save_root, fps=6):
+ id_name = os.path.splitext(os.path.basename(image_folder))[0]
+ if os.path.isfile(image_folder) and (image_folder.endswith("mp4") or image_folder.endswith("move")):
+ save_frame_root = os.path.join(save_root, "extracted_frames", id_name)
+ if not os.path.exists(save_frame_root):
+ os.makedirs(save_frame_root, exist_ok=True)
+ extract_imgs_from_video(video_file=image_folder, save_root=save_frame_root, fps=fps)
+ else:
+ print("skip extract_imgs_from_video......")
+ image_folder = save_frame_root
+
+ image_folder_abspath = os.path.abspath(image_folder)
+ print(f"predict motion seq:{image_folder_abspath}")
+ save_flame_root = image_folder + "_flame_params_mhmr"
+ if not os.path.exists(save_flame_root):
+ cmd = f"cd thirdparty/multi-hmr && python infer_batch.py --data_root {image_folder_abspath} --out_folder {image_folder_abspath} --crop_head --crop_hand --pad_ratio 0.2 --smplify"
+ os.system(cmd)
+ else:
+ print("skip predict flame.........")
+ return save_flame_root, image_folder
+
+
+def prepare_motion_seqs(motion_seqs_dir, image_folder, save_root, fps,
+ bg_color, aspect_standard, enlarge_ratio,
+ render_image_res, need_mask, multiply=16,
+ vis_motion=False):
+ if motion_seqs_dir is None:
+ assert image_folder is not None
+ motion_seqs_dir, image_folder = predict_motion_seqs_from_images(image_folder, save_root, fps)
+
+ motion_seqs = sorted(glob.glob(os.path.join(motion_seqs_dir, "*.json")))
+
+ # source images
+ c2ws, intrs, rgbs, bg_colors, masks = [], [], [], [], []
+ flame_params = []
+ shape_param = None
+
+ for idx, flame_path in enumerate(motion_seqs):
+ if image_folder is not None:
+ file_name = os.path.splitext(os.path.basename(flame_path))[0]
+ frame_path = os.path.join(image_folder, file_name + ".png")
+ if not os.path.exists(frame_path):
+ frame_path = os.path.join(image_folder, file_name + ".jpg")
+
+ with open(flame_path) as f:
+ flame_raw_data = json.load(f)
+ flame_param = {k: torch.FloatTensor(v) for k, v in flame_raw_data.items() if "pad_ratio" not in k}
+
+ if idx == 0:
+ shape_param = flame_param["betas"]
+
+ c2w, intrinsic = _load_pose(flame_param)
+ intrinsic_raw = intrinsic.clone()
+ if image_folder is not None:
+ rgb, mask, intrinsic = preprocess_image(frame_path, mask_path=None,
+ need_mask=need_mask,
+ bg_color=bg_color,
+ pad_ratio=float(flame_raw_data["pad_ratio"]),
+ max_tgt_size=int(flame_param["img_size_wh"][0]),
+ aspect_standard=aspect_standard,
+ enlarge_ratio=enlarge_ratio,
+ render_tgt_size=render_image_res,
+ multiply=multiply,
+ intr=intrinsic)
+ rgbs.append(rgb)
+ masks.append(mask)
+
+ c2ws.append(c2w)
+ bg_colors.append(bg_color)
+ intrs.append(intrinsic)
+ # intrs.append(intrinsic_raw)
+ flame_params.append(flame_param)
+
+ c2ws = torch.stack(c2ws, dim=0) # [N, 4, 4]
+ intrs = torch.stack(intrs, dim=0) # [N, 4, 4]
+ bg_colors = torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1).repeat(1, 3) # [N, 3]
+
+ if len(rgbs) > 0:
+ rgbs = torch.cat(rgbs, dim=0) # [N, 3, H, W]
+ # masks = torch.cat(masks, dim=0) # [N, 1, H, W]
+
+ flame_params_tmp = defaultdict(list)
+ for flame in flame_params:
+ for k, v in flame.items():
+ flame_params_tmp[k].append(v)
+ for k, v in flame_params_tmp.items():
+ flame_params_tmp[k] = torch.stack(v) # [Nv, xx, xx]
+ flame_params = flame_params_tmp
+ # TODO check different betas for same person
+ flame_params["betas"] = shape_param
+
+ if vis_motion:
+ motion_render = render_flame_mesh(flame_params, intrs)
+ else:
+ motion_render = None
+
+ # add batch dim
+ for k, v in flame_params.items():
+ flame_params[k] = v.unsqueeze(0)
+ # print(k, flame_params[k].shape, "motion_seq")
+ c2ws = c2ws.unsqueeze(0)
+ intrs = intrs.unsqueeze(0)
+ bg_colors = bg_colors.unsqueeze(0)
+ if len(rgbs) > 0:
+ rgbs = rgbs.unsqueeze(0)
+ # print(f"c2ws:{c2ws.shape}, intrs:{intrs.shape}, rgbs:{rgbs.shape if len(rgbs) > 0 else None}")
+
+ motion_seqs = {}
+ motion_seqs["render_c2ws"] = c2ws
+ motion_seqs["render_intrs"] = intrs
+ motion_seqs["render_bg_colors"] = bg_colors
+ motion_seqs["flame_params"] = flame_params
+ motion_seqs["rgbs"] = rgbs
+ motion_seqs["vis_motion_render"] = motion_render
+
+ return motion_seqs
\ No newline at end of file
diff --git a/lam/runners/train/__init__.py b/lam/runners/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceaab06a18ecc9ae602aeb2001472ade686d4e4d
--- /dev/null
+++ b/lam/runners/train/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .lam import LAMTrainer
diff --git a/lam/runners/train/base_trainer.py b/lam/runners/train/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..840bf451123d7bddb047aeb50ce8a463545121d3
--- /dev/null
+++ b/lam/runners/train/base_trainer.py
@@ -0,0 +1,461 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import traceback
+import os
+import time
+import math
+import argparse
+import shutil
+import torch
+import safetensors
+from omegaconf import OmegaConf
+from abc import abstractmethod
+from contextlib import contextmanager
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+import cv2
+import numpy as np
+
+from lam.utils.logging import configure_logger
+from lam.utils.compile import configure_dynamo
+from lam.runners.abstract import Runner
+
+
+logger = get_logger(__name__)
+
+
+def parse_configs():
+ # Define argparse arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, default='./assets/config.yaml')
+ parser.add_argument('--resume', type=str, default='')
+ args, unknown = parser.parse_known_args()
+
+ # Load configuration file
+ cfg = OmegaConf.load(args.config)
+
+ # Override with command-line arguments
+ cli_cfg = OmegaConf.from_cli(unknown)
+ cfg = OmegaConf.merge(cfg, cli_cfg)
+ if len(args.resume) > 0:
+ cfg.train.resume = args.resume
+
+ return cfg
+
+
+class Trainer(Runner):
+
+ def __init__(self):
+ super().__init__()
+
+ self.cfg = parse_configs()
+ self.has_disc = self.cfg.model.has_disc if hasattr(self.cfg.model, "has_disc") else False
+
+ self.timestamp = time.strftime("%Y%m%d-%H%M%S")
+
+ self.accelerator = Accelerator(
+ mixed_precision=self.cfg.train.mixed_precision,
+ gradient_accumulation_steps=self.cfg.train.accum_steps,
+ log_with=tuple(self.cfg.logger.trackers),
+ project_config=ProjectConfiguration(
+ logging_dir=self.cfg.logger.tracker_root,
+ ),
+ use_seedable_sampler=True,
+ kwargs_handlers=[
+ DistributedDataParallelKwargs(
+ find_unused_parameters=self.cfg.train.find_unused_parameters,
+ ),
+ ],
+ )
+
+ self.weight_dtype = self.get_weight_dtype()
+ print(f"weight_dtype:{self.weight_dtype}")
+
+ set_seed(self.cfg.experiment.seed, device_specific=True)
+ with self.accelerator.main_process_first():
+ configure_logger(
+ stream_level=self.cfg.logger.stream_level,
+ log_level=self.cfg.logger.log_level,
+ file_path=os.path.join(
+ self.cfg.logger.log_root,
+ self.cfg.experiment.parent, self.cfg.experiment.child,
+ f"{self.timestamp}.log",
+ ) if self.accelerator.is_main_process else None,
+ )
+ logger.info(self.accelerator.state, main_process_only=False, in_order=True)
+ configure_dynamo(dict(self.cfg.compile))
+
+ # attributes with defaults
+ self.model : torch.nn.Module = None
+ self.optimizer: torch.optim.Optimizer = None
+ self.scheduler: torch.optim.lr_scheduler.LRScheduler = None
+ self.train_loader: torch.utils.data.DataLoader = None
+ self.val_loader: torch.utils.data.DataLoader = None
+ self.N_max_global_steps: int = None
+ self.N_global_steps_per_epoch: int = None
+ self.global_step: int = 0
+ self.current_epoch: int = 0
+
+ def __enter__(self):
+ self.accelerator.init_trackers(
+ project_name=f"{self.cfg.experiment.parent}/{self.cfg.experiment.child}",
+ )
+ self.prepare_everything()
+ self.log_inital_info()
+
+ #self.accelerator.trackers[0].logging_dir
+ self.trackers_logging_dir = f"{self.cfg.logger.tracker_root}/{self.cfg.experiment.parent}/{self.cfg.experiment.child}"
+ os.makedirs(self.trackers_logging_dir, exist_ok=True)
+
+ self.snapshot_cfg(self.cfg)
+
+ return self
+
+ def get_weight_dtype(self):
+ weight_dtype = torch.float32
+ if self.accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif self.accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ elif self.accelerator.mixed_precision == "no":
+ weight_dtype = torch.float32
+ else:
+ raise NotImplementedError
+ return weight_dtype
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.accelerator.end_training()
+
+ @staticmethod
+ def control(option: str = None, synchronized: bool = False):
+ def decorator(func):
+ def wrapper(self, *args, **kwargs):
+ if option is None or hasattr(self.accelerator, option):
+ accelerated_func = getattr(self.accelerator, option)(func) if option is not None else func
+ result = accelerated_func(self, *args, **kwargs)
+ if synchronized:
+ self.accelerator.wait_for_everyone()
+ return result
+ else:
+ raise AttributeError(f"Accelerator has no attribute {option}")
+ return wrapper
+ return decorator
+
+ @contextmanager
+ def exec_in_order(self):
+ for rank in range(self.accelerator.num_processes):
+ try:
+ if self.accelerator.process_index == rank:
+ yield
+ finally:
+ self.accelerator.wait_for_everyone()
+
+ @property
+ def device(self):
+ return self.accelerator.device
+
+ @property
+ def is_distributed(self) -> bool:
+ return self.accelerator.num_processes > 1
+
+ def prepare_everything(self, is_dist_validation: bool = True):
+ # prepare with accelerator
+ if is_dist_validation:
+ if not self.has_disc:
+ self.model, self.optimizer, self.train_loader, self.val_loader = \
+ self.accelerator.prepare(
+ self.model, self.optimizer, self.train_loader, self.val_loader,
+ )
+ else:
+ self.model, self.model_disc, self.optimizer, self.optimizer_disc, self.train_loader, self.val_loader = \
+ self.accelerator.prepare(
+ self.model, self.model_disc, self.optimizer, self.optimizer_disc, self.train_loader, self.val_loader,
+ )
+ else:
+ if not self.has_disc:
+ self.model, self.optimizer, self.train_loader = \
+ self.accelerator.prepare(
+ self.model, self.optimizer, self.train_loader,
+ )
+ else:
+ self.model, self.model_disc, self.optimizer, self.optimizer_disc, self.train_loader = \
+ self.accelerator.prepare(
+ self.model, self.model_disc, self.optimizer, self.optimizer_disc, self.train_loader,
+ )
+
+ self.accelerator.register_for_checkpointing(self.scheduler)
+ if self.has_disc:
+ self.accelerator.register_for_checkpointing(self.scheduler_disc)
+ # prepare stats
+ N_total_batch_size = self.cfg.train.batch_size * self.accelerator.num_processes * self.cfg.train.accum_steps
+ self.N_global_steps_per_epoch = math.ceil(len(self.train_loader) / self.cfg.train.accum_steps)
+ self.N_max_global_steps = self.N_global_steps_per_epoch * self.cfg.train.epochs
+ if self.cfg.train.debug_global_steps is not None:
+ logger.warning(f"Overriding max global steps from {self.N_max_global_steps} to {self.cfg.train.debug_global_steps}")
+ self.N_max_global_steps = self.cfg.train.debug_global_steps
+ print(f"======== Trainable parameters ========")
+ print(f"** Total: {sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6}M")
+ logger.info(f"======== Statistics ========")
+ logger.info(f"** N_max_global_steps: {self.N_max_global_steps}")
+ logger.info(f"** N_total_batch_size: {N_total_batch_size}")
+ logger.info(f"** N_epochs: {self.cfg.train.epochs}")
+ logger.info(f"** N_global_steps_per_epoch: {self.N_global_steps_per_epoch}")
+ logger.debug(f"** Prepared loader length: {len(self.train_loader)}")
+ logger.info(f"** Distributed validation: {is_dist_validation}")
+ logger.info(f"============================")
+ logger.info(f"======== Trainable parameters ========")
+ logger.info(f"** Total: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
+ for sub_name, sub_module in self.accelerator.unwrap_model(self.model).named_children():
+ logger.info(f"** {sub_name}: {sum(p.numel() for p in sub_module.parameters() if p.requires_grad)}")
+ logger.info(f"=====================================")
+ self.accelerator.wait_for_everyone()
+ # load checkpoint or model
+ self.load_ckpt_or_auto_resume_(self.cfg)
+ # register hooks
+ self.register_hooks()
+
+ @abstractmethod
+ def register_hooks(self):
+ pass
+
+ def auto_resume_(self, cfg, ckpt_root=None) -> bool:
+ if ckpt_root is None:
+ ckpt_root = os.path.join(
+ cfg.saver.checkpoint_root,
+ cfg.experiment.parent, cfg.experiment.child,
+ )
+ if not os.path.exists(ckpt_root):
+ return False
+ ckpt_dirs = os.listdir(ckpt_root)
+ if len(ckpt_dirs) == 0:
+ return False
+ ckpt_dirs.sort()
+ latest_ckpt = ckpt_dirs[-1]
+ latest_ckpt_dir = os.path.join(ckpt_root, latest_ckpt)
+ logger.info(f"======== Auto-resume from {latest_ckpt_dir} ========")
+ self.accelerator.load_state(latest_ckpt_dir)
+ self.global_step = int(latest_ckpt)
+ self.current_epoch = self.global_step // self.N_global_steps_per_epoch
+ return True
+
+ def load_model_(self, cfg):
+ logger.info(f"======== Loading model from {cfg.saver.load_model} ========")
+
+ # model = self.accelerator.unwrap_model(self.model)
+ # state_dict = safetensors.torch.load_file(cfg.saver.load_model, device='cpu')
+ # state_dict.pop('pcl_embeddings.weight')
+ # model_state_dict = model.state_dict()
+ # missing, unexpected = model.load_state_dict(state_dict, strict=False)
+ # missing = set(missing)
+ # print("missing:", missing)
+ # print("unexpected:", unexpected)
+
+ try:
+ safetensors.torch.load_model(
+ self.accelerator.unwrap_model(self.model),
+ cfg.saver.load_model,
+ strict=cfg.saver.load_model_strict if hasattr(cfg.saver, "load_model_strict") else True,
+ )
+ except:
+ traceback.print_exc()
+ model = self.accelerator.unwrap_model(self.model)
+ model_state_dict = model.state_dict()
+ state_dict = safetensors.torch.load_file(cfg.saver.load_model, device='cpu')
+ for key in list(state_dict):
+ if "renderer.flame_model" in key:
+ print(f"pop:{key}, shape:{state_dict[key].shape}")
+ state_dict.pop(key)
+ if "renderer.flame_model" in key:
+ print(f"pop:{key}, shape:{state_dict[key].shape}")
+ state_dict.pop(key)
+ if "renderer.gs_net.out_layers.scaling.weight" == key:
+ if state_dict["renderer.gs_net.out_layers.scaling.weight"].shape != model_state_dict["renderer.gs_net.out_layers.scaling.weight"].shape:
+ # state_dict["renderer.gs_net.out_layers.scaling.weight"] = state_dict["renderer.gs_net.out_layers.scaling.weight"][:1]
+ # state_dict["renderer.gs_net.out_layers.scaling.bias"] = state_dict["renderer.gs_net.out_layers.scaling.bias"][:1]
+ state_dict.pop("renderer.gs_net.out_layers.scaling.weight")
+ state_dict.pop("renderer.gs_net.out_layers.scaling.bias")
+
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
+ missing = set(missing)
+ print("missing:", missing)
+ print("unexpected:", unexpected)
+
+ if self.has_disc and cfg.saver.get("load_model_disc", None) is not None:
+ safetensors.torch.load_model(
+ self.accelerator.unwrap_model(self.model_disc),
+ cfg.saver.load_model_disc,
+ strict=cfg.saver.load_model_strict if hasattr(cfg.saver, "load_model_strict") else True,
+ )
+ logger.info(f"======== Model loaded ========")
+
+ @control(synchronized=True)
+ def load_ckpt_or_auto_resume_(self, cfg):
+ # auto resume has higher priority, load model from path if auto resume is not available
+ # cfg.saver.auto_resume and cfg.saver.load_model
+
+ if hasattr(cfg.saver, "load_ckpt") and cfg.saver.load_ckpt:
+ successful_resume = self.auto_resume_(cfg, ckpt_root=cfg.saver.load_ckpt)
+ if successful_resume:
+ return
+
+ if cfg.saver.auto_resume:
+ successful_resume = self.auto_resume_(cfg)
+ if successful_resume:
+ return
+
+ if cfg.saver.load_model:
+ successful_load = self.load_model_(cfg)
+ if successful_load:
+ return
+ logger.debug(f"======== No checkpoint or model is loaded ========")
+
+
+ # @control('on_main_process', synchronized=True)
+ def _save_checkpoint(self):
+ ckpt_dir = os.path.join(
+ self.cfg.saver.checkpoint_root,
+ self.cfg.experiment.parent, self.cfg.experiment.child,
+ f"{self.global_step:06d}",
+ )
+ self.accelerator.save_state(output_dir=ckpt_dir, safe_serialization=True)
+ logger.info(f"======== Saved checkpoint at global step {self.global_step} ========")
+ # manage stratified checkpoints
+ ckpt_dirs = os.listdir(os.path.dirname(ckpt_dir))
+ ckpt_dirs.sort()
+ max_ckpt = int(ckpt_dirs[-1])
+ ckpt_base = int(self.cfg.saver.checkpoint_keep_level)
+ ckpt_period = self.cfg.saver.checkpoint_global_steps
+ logger.debug(f"Checkpoint base: {ckpt_base}")
+ logger.debug(f"Checkpoint period: {ckpt_period}")
+ cur_order = ckpt_base ** math.floor(math.log(max_ckpt // ckpt_period, ckpt_base))
+ cur_idx = 0
+ while cur_order > 0:
+ cur_digit = max_ckpt // ckpt_period // cur_order % ckpt_base
+ while cur_idx < len(ckpt_dirs) and int(ckpt_dirs[cur_idx]) // ckpt_period // cur_order % ckpt_base < cur_digit:
+ if int(ckpt_dirs[cur_idx]) // ckpt_period % cur_order != 0:
+ shutil.rmtree(os.path.join(os.path.dirname(ckpt_dir), ckpt_dirs[cur_idx]))
+ logger.info(f"Removed checkpoint {ckpt_dirs[cur_idx]}")
+ cur_idx += 1
+ cur_order //= ckpt_base
+
+ def save_checkpoint(self):
+ if self.accelerator.state.deepspeed_plugin is not None:
+ logger.info("deepspeed mode to save ckpt...............")
+ self._save_checkpoint()
+ else:
+ if self.accelerator.is_main_process:
+ self._save_checkpoint()
+
+ @control('on_main_process')
+ def snapshot_cfg(self, cfg):
+ # save_path=os.path.join(self.accelerator.trackers[0].logging_dir, "config.yaml")
+ save_path=os.path.join(self.trackers_logging_dir, "config.yaml")
+ OmegaConf.save(cfg, save_path)
+
+ @property
+ def global_step_in_epoch(self):
+ return self.global_step % self.N_global_steps_per_epoch
+
+ @abstractmethod
+ def _build_model(self):
+ pass
+
+ @abstractmethod
+ def _build_optimizer(self):
+ pass
+
+ @abstractmethod
+ def _build_scheduler(self):
+ pass
+
+ @abstractmethod
+ def _build_dataloader(self):
+ pass
+
+ @abstractmethod
+ def _build_loss_fn(self):
+ pass
+
+ @abstractmethod
+ def train(self):
+ pass
+
+ @abstractmethod
+ def evaluate(self):
+ pass
+
+ @staticmethod
+ def _get_str_progress(epoch: int = None, step: int = None):
+ if epoch is not None:
+ log_type = 'epoch'
+ log_progress = epoch
+ elif step is not None:
+ log_type = 'step'
+ log_progress = step
+ else:
+ raise ValueError('Either epoch or step must be provided')
+ return log_type, log_progress
+
+ @control('on_main_process')
+ def log_scalar_kwargs(self, epoch: int = None, step: int = None, split: str = None, **scalar_kwargs):
+ log_type, log_progress = self._get_str_progress(epoch, step)
+ split = f'/{split}' if split else ''
+ for key, value in scalar_kwargs.items():
+ self.accelerator.log({f'{key}{split}/{log_type}': value}, log_progress)
+
+ def log_images_each_process(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):
+ for tracker in self.accelerator.trackers:
+ if hasattr(tracker, 'log_images'):
+ tracker.log_images(values, step=step, **log_kwargs.get(tracker.name, {}))
+ # log_dir = tracker.logging_dir
+ log_dir = self.trackers_logging_dir
+ if log_kwargs.get("imwrite_image", True):
+ for k, v in values.items():
+ v = v[0].permute(1, 2, 0).detach().cpu().numpy()
+ save_path = os.path.join(log_dir, f"{step:05d}_{k.replace('/', '_')}.jpg")
+ # print(save_path)
+ cv2.imwrite(save_path, (v * 255).astype(np.uint8)[:, :, (2, 1, 0)])
+
+ @control('on_main_process')
+ def log_images(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):
+ self.log_images_each_process(values, step, log_kwargs)
+
+
+ @control('on_main_process')
+ def log_optimizer(self, epoch: int = None, step: int = None, attrs: list[str] = [], group_ids: list[int] = []):
+ log_type, log_progress = self._get_str_progress(epoch, step)
+ assert self.optimizer is not None, 'Optimizer is not initialized'
+ if not attrs:
+ logger.warning('No optimizer attributes are provided, nothing will be logged')
+ if not group_ids:
+ logger.warning('No optimizer group ids are provided, nothing will be logged')
+ for attr in attrs:
+ assert attr in ['lr', 'momentum', 'weight_decay'], f'Invalid optimizer attribute {attr}'
+ for group_id in group_ids:
+ self.accelerator.log({f'opt/{attr}/{group_id}': self.optimizer.param_groups[group_id][attr]}, log_progress)
+
+ @control('on_main_process')
+ def log_inital_info(self):
+ assert self.model is not None, 'Model is not initialized'
+ assert self.optimizer is not None, 'Optimizer is not initialized'
+ assert self.scheduler is not None, 'Scheduler is not initialized'
+ self.accelerator.log({'Config': "```\n" + OmegaConf.to_yaml(self.cfg) + "\n```"})
+ self.accelerator.log({'Model': "```\n" + str(self.model) + "\n```"})
+ self.accelerator.log({'Optimizer': "```\n" + str(self.optimizer) + "\n```"})
+ self.accelerator.log({'Scheduler': "```\n" + str(self.scheduler) + "\n```"})
+
+ def run(self):
+ self.train()
diff --git a/lam/runners/train/lam.py b/lam/runners/train/lam.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e30a1339085860416b198d1c82903ac29b64429
--- /dev/null
+++ b/lam/runners/train/lam.py
@@ -0,0 +1,869 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import math
+from tqdm.auto import tqdm
+import torch
+import torch.nn as nn
+import torchvision
+import numpy as np
+from torchvision.utils import make_grid
+from einops import rearrange, repeat
+from accelerate.logging import get_logger
+from taming.modules.losses.vqperceptual import hinge_d_loss
+
+from .base_trainer import Trainer
+from lam.utils.profiler import DummyProfiler
+from lam.runners import REGISTRY_RUNNERS
+from lam.utils.hf_hub import wrap_model_hub
+from safetensors.torch import load_file
+from pytorch3d.ops.knn import knn_points
+import torch.nn.functional as F
+
+logger = get_logger(__name__)
+
+# torch.autograd.set_detect_anomaly(True)
+
+
+from omegaconf import OmegaConf
+@REGISTRY_RUNNERS.register('train.lam')
+class LAMTrainer(Trainer):
+
+ EXP_TYPE: str = 'lam'
+
+ def __init__(self):
+ super().__init__()
+
+ self.model = self._build_model(self.cfg)
+ if self.has_disc:
+ self.model_disc = self._build_model_disc(self.cfg)
+ self.optimizer = self._build_optimizer(self.model, self.cfg)
+ if self.has_disc:
+ self.optimizer_disc = self._build_optimizer(self.model_disc, self.cfg)
+
+ self.train_loader, self.val_loader = self._build_dataloader(self.cfg)
+ self.scheduler = self._build_scheduler(self.optimizer, self.cfg)
+ if self.has_disc:
+ self.scheduler_disc = self._build_scheduler(self.optimizer_disc, self.cfg)
+ self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = self._build_loss_fn(self.cfg)
+ self.only_sym_conf = 2
+ print("==="*16*3, "\n"+"only_sym_conf:", self.only_sym_conf, "\n"+"==="*16*3)
+
+
+ def _build_model(self, cfg):
+ assert cfg.experiment.type == 'lrm', \
+ f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}"
+ from lam.models import ModelLAM
+ model = ModelLAM(**cfg.model)
+
+ # resume
+ if len(self.cfg.train.resume) > 0:
+ resume = self.cfg.train.resume
+ print("==="*16*3)
+ self.accelerator.print("loading pretrained weight from:", resume)
+ if resume.endswith('safetensors'):
+ ckpt = load_file(resume, device='cpu')
+ else:
+ ckpt = torch.load(resume, map_location='cpu')
+ state_dict = model.state_dict()
+ for k, v in ckpt.items():
+ if k in state_dict:
+ if state_dict[k].shape == v.shape:
+ state_dict[k].copy_(v)
+ else:
+ self.accelerator.print(f"WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.")
+ else:
+ self.accelerator.print(f"WARN] unexpected param {k}: {v.shape}")
+ self.accelerator.print("Finish loading ckpt:", resume, "\n"+"==="*16*3)
+ return model
+
+ def _build_model_disc(self, cfg):
+ if cfg.model.disc.type == "pix2pix":
+ from lam.models.discriminator import NLayerDiscriminator, weights_init
+ model = NLayerDiscriminator(input_nc=cfg.model.disc.in_channels,
+ n_layers=cfg.model.disc.num_layers,
+ use_actnorm=cfg.model.disc.use_actnorm
+ ).apply(weights_init)
+
+ elif cfg.model.disc.type == "vqgan":
+ from lam.models.discriminator import Discriminator
+ model = Discriminator(in_channels=cfg.model.disc.in_channels,
+ cond_channels=0, hidden_channels=512,
+ depth=cfg.model.disc.depth)
+ elif cfg.model.disc.type == "stylegan":
+ from lam.models.gan.stylegan_discriminator import SingleDiscriminatorV2, SingleDiscriminator
+ from lam.models.gan.stylegan_discriminator_torch import Discriminator
+
+ model = Discriminator(512, channel_multiplier=2)
+
+ model.input_size = cfg.model.disc.img_res
+ else:
+ raise NotImplementedError
+ return model
+
+ def _build_optimizer(self, model: nn.Module, cfg):
+ decay_params, no_decay_params = [], []
+
+ # add all bias and LayerNorm params to no_decay_params
+ for name, module in model.named_modules():
+ if isinstance(module, nn.LayerNorm):
+ no_decay_params.extend([p for p in module.parameters()])
+ elif hasattr(module, 'bias') and module.bias is not None:
+ no_decay_params.append(module.bias)
+
+ # add remaining parameters to decay_params
+ _no_decay_ids = set(map(id, no_decay_params))
+ decay_params = [p for p in model.parameters() if id(p) not in _no_decay_ids]
+
+ # filter out parameters with no grad
+ decay_params = list(filter(lambda p: p.requires_grad, decay_params))
+ no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params))
+
+ # monitor this to make sure we don't miss any parameters
+ logger.info("======== Weight Decay Parameters ========")
+ logger.info(f"Total: {len(decay_params)}")
+ logger.info("======== No Weight Decay Parameters ========")
+ logger.info(f"Total: {len(no_decay_params)}")
+
+ # Optimizer
+ opt_groups = [
+ {'params': decay_params, 'weight_decay': cfg.train.optim.weight_decay},
+ {'params': no_decay_params, 'weight_decay': 0.0},
+ ]
+ optimizer = torch.optim.AdamW(
+ opt_groups,
+ lr=cfg.train.optim.lr,
+ betas=(cfg.train.optim.beta1, cfg.train.optim.beta2),
+ )
+
+ return optimizer
+
+ def _build_scheduler(self, optimizer, cfg):
+ local_batches_per_epoch = math.floor(len(self.train_loader) / self.accelerator.num_processes)
+ total_global_batches = cfg.train.epochs * math.ceil(local_batches_per_epoch / self.cfg.train.accum_steps)
+ effective_warmup_iters = cfg.train.scheduler.warmup_real_iters
+ logger.debug(f"======== Scheduler effective max iters: {total_global_batches} ========")
+ logger.debug(f"======== Scheduler effective warmup iters: {effective_warmup_iters} ========")
+ if cfg.train.scheduler.type == 'cosine':
+ from lam.utils.scheduler import CosineWarmupScheduler
+ scheduler = CosineWarmupScheduler(
+ optimizer=optimizer,
+ warmup_iters=effective_warmup_iters,
+ max_iters=total_global_batches,
+ )
+ else:
+ raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented")
+ return scheduler
+
+ def _build_dataloader(self, cfg):
+ # dataset class
+ from lam.datasets import MixerDataset
+ gaga_track_type = cfg.dataset.get("gaga_track_type", "vfhq_gagtrack")
+ sample_aug_views = cfg.dataset.get("sample_aug_views", 0)
+
+ # build dataset
+ load_normal = cfg.train.loss.get("normal_weight", False) > 0. if hasattr(cfg.train.loss, "normal_weight") else False
+ load_normal = load_normal or (cfg.train.loss.get("surfel_normal_weight", False) > 0. if hasattr(cfg.train.loss, "surfel_normal_weight") else False)
+ print("==="*16*3, "\nload_normal:", load_normal)
+ train_dataset = MixerDataset(
+ split="train",
+ subsets=cfg.dataset.subsets,
+ sample_side_views=cfg.dataset.sample_side_views,
+ render_image_res_low=cfg.dataset.render_image.low,
+ render_image_res_high=cfg.dataset.render_image.high,
+ render_region_size=cfg.dataset.render_image.region,
+ source_image_res=cfg.dataset.source_image_res,
+ repeat_num=cfg.dataset.repeat_num if hasattr(cfg.dataset, "repeat_num") else 1,
+ multiply=cfg.dataset.multiply if hasattr(cfg.dataset, "multiply") else 14,
+ debug=cfg.dataset.debug if hasattr(cfg.dataset, "debug") else False,
+ is_val=False,
+ gaga_track_type=gaga_track_type,
+ sample_aug_views=sample_aug_views,
+ load_albedo=cfg.model.get("render_albedo", False) if hasattr(cfg.model, "render_albedo") else False,
+ load_normal=load_normal,
+ )
+ val_dataset = MixerDataset(
+ split="val",
+ subsets=cfg.dataset.subsets,
+ sample_side_views=cfg.dataset.sample_side_views,
+ render_image_res_low=cfg.dataset.render_image.low,
+ render_image_res_high=cfg.dataset.render_image.high,
+ render_region_size=cfg.dataset.render_image.region,
+ source_image_res=cfg.dataset.source_image_res,
+ repeat_num=cfg.dataset.repeat_num if hasattr(cfg.dataset, "repeat_num") else 1,
+ multiply=cfg.dataset.multiply if hasattr(cfg.dataset, "multiply") else 14,
+ debug=cfg.dataset.debug if hasattr(cfg.dataset, "debug") else False,
+ is_val=True,
+ gaga_track_type=gaga_track_type,
+ sample_aug_views=sample_aug_views,
+ load_albedo=cfg.model.get("render_albedo", False) if hasattr(cfg.model, "render_albedo") else False,
+ load_normal=load_normal,
+ )
+
+ # build data loader
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=cfg.train.batch_size,
+ shuffle=True,
+ drop_last=True,
+ num_workers=cfg.dataset.num_train_workers,
+ pin_memory=cfg.dataset.pin_mem,
+ persistent_workers=True,
+ )
+ val_loader = torch.utils.data.DataLoader(
+ val_dataset,
+ batch_size=cfg.val.batch_size,
+ shuffle=False,
+ drop_last=False,
+ num_workers=cfg.dataset.num_val_workers,
+ pin_memory=cfg.dataset.pin_mem,
+ persistent_workers=False,
+ )
+
+ return train_loader, val_loader
+
+ def _build_loss_fn(self, cfg):
+ from lam.losses import PixelLoss, LPIPSLoss, TVLoss
+ pixel_loss_fn = PixelLoss(option=cfg.train.loss.get("pixel_loss_fn", "mse"))
+ with self.accelerator.main_process_first():
+ perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True)
+
+ if cfg.model.get("use_conf_map", False):
+ assert cfg.train.loss.get("head_pl", False), "Set head_pl in train.loss to true to use faceperceptualloss when using conf_map."
+ tv_loss_fn = TVLoss()
+ return pixel_loss_fn, perceptual_loss_fn, tv_loss_fn
+
+ def register_hooks(self):
+ pass
+
+ def get_flame_params(self, data, is_source=False):
+ flame_params = {}
+ flame_keys = ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas',\
+ 'rotation', 'neck_pose', 'eyes_pose', 'translation', "teeth_bs"]
+ if is_source:
+ flame_keys = ['source_'+item for item in flame_keys]
+ for k, v in data.items():
+ if k in flame_keys:
+ # print(k, v.shape)
+ flame_params[k] = data[k]
+ return flame_params
+
+ def cross_copy(self, data):
+ B = data.shape[0]
+ assert data.shape[1] == 1
+ new_data = []
+ for i in range(B):
+ B_i = [data[i]]
+ for j in range(B):
+ if j != i:
+ B_i.append(data[j])
+ new_data.append(torch.concat(B_i, dim=0))
+ new_data = torch.stack(new_data, dim=0)
+
+ return new_data
+
+ def prepare_cross_render_data(self, data):
+ B, N_v, C, H, W = data['render_image'].shape
+ assert N_v == 1
+
+ # cross copy
+ data["c2ws"] = self.cross_copy(data["c2ws"])
+ data["intrs"] = self.cross_copy(data["intrs"])
+ data["render_full_resolutions"] = self.cross_copy(data["render_full_resolutions"])
+ data["render_image"] = self.cross_copy(data["render_image"])
+ data["render_mask"] = self.cross_copy(data["render_mask"])
+ data["render_bg_colors"] = self.cross_copy(data["render_bg_colors"])
+ flame_params = self.get_flame_params(data)
+ for key in flame_params.keys():
+ if "betas" not in key:
+ data[key] = self.cross_copy(data[key])
+ source_flame_params = self.get_flame_params(data, is_source=True)
+ for key in source_flame_params.keys():
+ if "betas" not in key:
+ data[key] = self.cross_copy(data[key])
+
+ return data
+
+ def get_loss_weight(self, loss_weight):
+ if isinstance(loss_weight, str) and ":" in loss_weight:
+ start_step, start_value, end_value, end_step = map(float, loss_weight.split(":"))
+ current_step = self.global_step
+ value = start_value + (end_value - start_value) * max(
+ min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
+ )
+ return value
+ elif isinstance(loss_weight, (float, int)):
+ return loss_weight
+ else:
+ raise NotImplementedError
+
+ def forward_loss_local_step(self, data):
+ render_image = data['render_image']
+ render_albedo = data.get('render_albedo', None)
+ render_mask = data['render_mask']
+ render_normal = data.get('render_normal', None)
+ B, N_v, C, H, W = render_image.shape
+ flame_params = self.get_flame_params(data)
+ source_flame_params = self.get_flame_params(data, is_source=True)
+
+ # forward
+ outputs = self.model(
+ image=data['source_rgbs'],
+ source_c2ws=data['source_c2ws'],
+ source_intrs=data['source_intrs'],
+ render_c2ws=data['c2ws'],
+ render_intrs=data['intrs'],
+ render_bg_colors=data['render_bg_colors'],
+ flame_params=flame_params,
+ source_flame_params=source_flame_params,
+ render_images=render_image,
+ data = data
+ )
+
+ # loss calculation
+ loss = 0.
+ loss_pixel = None
+ loss_perceptual = None
+ loss_mask = None
+ extra_loss_dict = {}
+
+ num_aug_view = self.cfg.dataset.get("sample_aug_views", 0)
+ real_num_view = data["real_num_view"] - num_aug_view
+
+ conf_sigma_l1 = outputs.get("conf_sigma_l1", None)
+ conf_sigma_percl = outputs.get("conf_sigma_percl", None)
+ if self.cfg.model.use_sym_proj:
+ real_num_view *= 2
+ if self.cfg.model.use_conf_map:
+ conf_sigma_l1 = rearrange(conf_sigma_l1, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view]
+ conf_sigma_percl = rearrange(conf_sigma_percl, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view]
+ render_image = repeat(data['render_image'], "b v c h w -> b (v r) c h w", r=2)
+ render_albedo = repeat(render_albedo, "b v c h w -> b (v r) c h w", r=2) if render_albedo is not None else None
+ render_mask = repeat(data['render_mask'], "b v c h w -> b (v r) c h w", r=2)
+ if "render_normal" in data.keys():
+ render_normal = repeat(data['render_normal'], "b v c h w -> b (v r) c h w", r=2)
+ for k, v in data.items():
+ if "bbox" in k:
+ data[k] = repeat(v, "b v c -> b (v r) c", r=2)
+
+ only_sym_conf = self.only_sym_conf
+
+ if self.get_loss_weight(self.cfg.train.loss.get("masked_pixel_weight", 0)) > 0.:
+ gt_rgb = render_image[:, :real_num_view] * render_mask[:, :real_num_view] + 1.0 * (1 - render_mask[:, :real_num_view])
+ pred_rgb = outputs['comp_rgb'][:, :real_num_view] * render_mask[:, :real_num_view] + 1.0 * (1 - render_mask[:, :real_num_view])
+
+ loss_pixel = self.pixel_loss_fn(pred_rgb, gt_rgb, conf_sigma_l1, only_sym_conf=only_sym_conf) * self.get_loss_weight(self.cfg.train.loss.masked_pixel_weight)
+ loss += loss_pixel
+
+ # using same weight
+ loss_perceptual = self.perceptual_loss_fn(pred_rgb, gt_rgb, conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf) * self.get_loss_weight(self.cfg.train.loss.masked_pixel_weight)
+ loss += loss_perceptual
+
+ if self.get_loss_weight(self.cfg.train.loss.pixel_weight) > 0.:
+ total_loss_pixel = loss_pixel
+ if (hasattr(self.cfg.train.loss, 'rgb_weight') and self.get_loss_weight(self.cfg.train.loss.rgb_weight) > 0.) or not hasattr(self.cfg.train.loss, "rgb_weight"):
+ loss_pixel = self.pixel_loss_fn(
+ outputs['comp_rgb'][:, :real_num_view], render_image[:, :real_num_view], conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf
+ ) * self.get_loss_weight(self.cfg.train.loss.pixel_weight)
+ loss += loss_pixel
+ if total_loss_pixel is not None:
+ loss_pixel += total_loss_pixel
+
+ if self.get_loss_weight(self.cfg.train.loss.perceptual_weight) > 0.:
+ total_loss_perceptual = loss_perceptual
+ if (hasattr(self.cfg.train.loss, 'rgb_weight') and self.get_loss_weight(self.cfg.train.loss.rgb_weight) > 0.) or not hasattr(self.cfg.train.loss, "rgb_weight"):
+ loss_perceptual = self.perceptual_loss_fn(
+ outputs['comp_rgb'][:, :real_num_view], render_image[:, :real_num_view], conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf
+ ) * self.get_loss_weight(self.cfg.train.loss.perceptual_weight)
+ loss += loss_perceptual
+ if total_loss_perceptual is not None:
+ loss_perceptual += total_loss_perceptual
+
+ if self.get_loss_weight(self.cfg.train.loss.mask_weight) > 0. and 'comp_mask' in outputs.keys():
+ loss_mask = self.pixel_loss_fn(outputs['comp_mask'][:, :real_num_view], render_mask[:, :real_num_view], conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf
+ ) * self.get_loss_weight(self.cfg.train.loss.mask_weight)
+ loss += loss_mask
+
+ if hasattr(self.cfg.train.loss, 'offset_reg_weight') and self.get_loss_weight(self.cfg.train.loss.offset_reg_weight) > 0.:
+ loss_offset_reg = 0
+ for b_idx in range(len(outputs['3dgs'])):
+ loss_offset_reg += torch.nn.functional.mse_loss(outputs['3dgs'][b_idx][0].offset.float(), torch.zeros_like(outputs['3dgs'][b_idx][0].offset.float()))
+ loss_offset_reg = loss_offset_reg / len(outputs['3dgs'])
+ loss += loss_offset_reg * self.get_loss_weight(self.cfg.train.loss.offset_reg_weight)
+ else:
+ loss_offset_reg = None
+
+ return outputs, loss, loss_pixel, loss_perceptual, loss_offset_reg, loss_mask, extra_loss_dict
+
+ def adopt_weight(self, weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer, discriminator_weight=1):
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * discriminator_weight
+ return d_weight
+
+ def disc_preprocess(self, img):
+ # reshape [B, N_v, C, H, W] to [B*N_v, C, H, W]
+ img = torch.flatten(img, 0, 1)
+ # img = rearrange(img, 'b n c h w -> (b n) c h w')
+ # convert 0-1 to -1-1
+ img = 2 * img - 1
+
+ if hasattr(self.accelerator.unwrap_model(self.model_disc), "input_size"):
+ tgt_size = self.accelerator.unwrap_model(self.model_disc).input_size
+ img = nn.functional.interpolate(img, (tgt_size, tgt_size))
+ img = img.float()
+
+ return img
+
+ def forward_to_get_loss_with_gen_loss(self, data):
+ # forward to loss
+ outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)
+
+ with torch.autocast(device_type=outs["comp_rgb"].device.type, dtype=torch.float32):
+ logits_fake = self.model_disc(self.disc_preprocess(outs["comp_rgb"]))
+
+ loss_gen = -torch.mean(logits_fake)
+
+ try:
+ if loss < 1e-5:
+ d_weight = self.cfg.model.disc.disc_weight
+ else:
+ nll_loss = loss_pixel
+ if nll_loss is None:
+ nll_loss = loss
+ d_weight = self.calculate_adaptive_weight(nll_loss, loss_gen,
+ last_layer=self.accelerator.unwrap_model(self.model).get_last_layer(),
+ discriminator_weight=self.cfg.model.disc.disc_weight)
+ except RuntimeError:
+ print("*************Error when calculate_adaptive_weight************")
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = self.adopt_weight(1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start)
+ # print(disc_factor, d_weight)
+
+ loss += disc_factor * d_weight * loss_gen
+
+ # backward
+ self.accelerator.backward(loss)
+ if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm)
+
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict
+
+
+ def forward_to_get_loss(self, data):
+ # forward to loss
+ outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)
+
+ # backward
+ self.accelerator.backward(loss)
+ if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm)
+
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict
+
+
+ def forward_disc_loss_local_step(self, pred_img, gt_img):
+ # detach gradient of pred_img
+ with torch.autocast(device_type=pred_img.device.type, dtype=torch.float32):
+ logits_real = self.model_disc(self.disc_preprocess(gt_img).detach())
+ logits_fake = self.model_disc(self.disc_preprocess(pred_img).detach())
+
+ loss_disc = hinge_d_loss(logits_real, logits_fake)
+ return loss_disc
+
+
+ def forward_to_get_disc_loss(self, pred_img, gt_img):
+ # forward to loss
+ loss_disc = self.forward_disc_loss_local_step(pred_img, gt_img)
+
+ disc_factor = self.adopt_weight(1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start)
+ loss = disc_factor * loss_disc
+
+ # backward
+ self.accelerator.backward(loss)
+
+ if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
+ self.accelerator.clip_grad_norm_(self.model_disc.parameters(), self.cfg.train.optim.clip_grad_norm)
+
+ self.optimizer_disc.step()
+ self.optimizer_disc.zero_grad()
+
+ return loss_disc
+
+ def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile, iepoch: int):
+
+ self.model.train()
+ if self.has_disc:
+ self.model_disc.train()
+
+ local_step_losses = []
+ global_step_losses = []
+ local_step_extra_losses = []
+ global_step_extra_losses = []
+ extra_loss_keys = []
+
+ logger.debug(f"======== Starting epoch {self.current_epoch} ========")
+ loss_disc = None
+ for idx, data in enumerate(loader):
+ data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype)
+ if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render:
+ data = self.prepare_cross_render_data(data)
+ data["real_num_view"] = 1
+ else:
+ data["real_num_view"] = data["render_image"].shape[1]
+
+ logger.debug(f"======== Starting global step {self.global_step} ========")
+
+ if not self.has_disc:
+ disc_step = False
+ with self.accelerator.accumulate(self.model):
+ outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_to_get_loss(data)
+
+ # track local losses
+ loss_disc, loss_gen = None, None
+ local_step_losses.append(torch.stack([
+ _loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device)
+ for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_disc, loss_gen]
+ ]))
+ extra_loss_keys = sorted(list(extra_loss_dict.keys()))
+ if len(extra_loss_keys) > 0:
+ local_step_extra_losses.append(torch.stack([
+ extra_loss_dict[k].detach() if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device)
+ for k in extra_loss_keys
+ ]))
+ else:
+ disc_step = (idx % 5) == 0 or (iepoch * len(loader) + idx < 100 and idx % 2 == 0)
+ local_step_losses_bak = torch.zeros(6, device=data["source_rgbs"].device)
+ if not disc_step:
+ with self.accelerator.accumulate(self.model):
+ # generator step
+ outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict = self.forward_to_get_loss_with_gen_loss(data)
+ # track local losses
+ local_step_losses.append(torch.stack([
+ _loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device)
+ for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc]
+ ]))
+ local_step_losses_bak = local_step_losses[-1].detach()
+ torch.cuda.empty_cache()
+ extra_loss_keys = sorted(list(extra_loss_dict.keys()))
+ if len(extra_loss_keys) > 0:
+ local_step_extra_losses.append(torch.stack([
+ extra_loss_dict[k].detach() if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device)
+ for k in extra_loss_keys
+ ]))
+ else:
+ with self.accelerator.accumulate(self.model_disc):
+ # discriminator step
+ outs, _, _, _, _, _, _ = self.forward_loss_local_step(data)
+ loss_disc = self.forward_to_get_disc_loss(pred_img=outs["comp_rgb"],
+ gt_img=data["render_image"])
+ local_step_losses.append(torch.concat([local_step_losses_bak[:6], loss_disc.unsqueeze(0)], dim=0))
+ torch.cuda.empty_cache()
+
+ # track global step
+ if self.accelerator.sync_gradients:
+ profiler.step()
+ if not disc_step:
+ self.scheduler.step()
+ if self.has_disc and disc_step:
+ self.scheduler_disc.step()
+ logger.debug(f"======== Scheduler step ========")
+ self.global_step += 1
+ global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu()
+ if len(extra_loss_keys) > 0:
+ global_step_extra_loss = self.accelerator.gather(torch.stack(local_step_extra_losses)).mean(dim=0).cpu()
+ global_step_extra_loss_items = global_step_extra_loss.unbind()
+ else:
+ global_step_extra_loss = None
+ global_step_extra_loss_items = []
+ loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc_ = global_step_loss.unbind()
+ loss_kwargs = {
+ 'loss': loss.item(),
+ 'loss_pixel': loss_pixel.item(),
+ 'loss_perceptual': loss_perceptual.item(),
+ 'loss_tv': loss_tv.item(),
+ 'loss_mask': loss_mask.item(),
+ 'loss_disc': loss_disc_.item(),
+ 'loss_gen': loss_gen.item(),
+ }
+ for k, loss in zip(extra_loss_keys, global_step_extra_loss_items):
+ loss_kwargs[k] = loss.item()
+ self.log_scalar_kwargs(
+ step=self.global_step, split='train',
+ **loss_kwargs
+ )
+ self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1])
+ local_step_losses = []
+ global_step_losses.append(global_step_loss)
+ local_step_extra_losses = []
+ global_step_extra_losses.append(global_step_extra_loss)
+
+ # manage display
+ pbar.update(1)
+ description = {
+ **loss_kwargs,
+ 'lr': self.optimizer.param_groups[0]['lr'],
+ }
+ description = '[TRAIN STEP]' + \
+ ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v))
+ pbar.set_description(description)
+
+ # periodic actions
+ if self.global_step % self.cfg.saver.checkpoint_global_steps == 0:
+ self.save_checkpoint()
+ if self.global_step % self.cfg.val.global_step_period == 0:
+ self.evaluate()
+ self.model.train()
+ if self.has_disc:
+ self.model_disc.train()
+ if (self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0) or (self.global_step < 1000 and self.global_step % 20 == 0):
+ conf_sigma_l1 = outs.get('conf_sigma_l1', None)
+ conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None
+ conf_sigma_percl = outs.get('conf_sigma_percl', None)
+ conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None
+ self.log_image_monitor(
+ step=self.global_step, split='train',
+ renders=outs['comp_rgb'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl,
+ gts=data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ )
+ if 'comp_mask' in outs.keys():
+ self.log_image_monitor(
+ step=self.global_step, split='train',
+ renders=outs['comp_mask'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ gts=data['render_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ prefix="_mask",
+ )
+
+ # progress control
+ if self.global_step >= self.N_max_global_steps:
+ self.accelerator.set_trigger()
+ break
+
+ # track epoch
+ self.current_epoch += 1
+ epoch_losses = torch.stack(global_step_losses).mean(dim=0)
+ epoch_loss, epoch_loss_pixel, epoch_loss_perceptual, epoch_loss_tv, epoch_loss_mask, epoch_loss_disc, epoch_loss_gen = epoch_losses.unbind()
+ epoch_loss_dict = {
+ 'loss': epoch_loss.item(),
+ 'loss_pixel': epoch_loss_pixel.item(),
+ 'loss_perceptual': epoch_loss_perceptual.item(),
+ 'loss_tv': epoch_loss_tv.item(),
+ 'loss_mask': epoch_loss_mask.item(),
+ 'loss_disc': epoch_loss_disc.item(),
+ 'loss_gen': epoch_loss_gen.item(),
+ }
+ if len(extra_loss_keys) > 0:
+ epoch_extra_losses = torch.stack(global_step_extra_losses).mean(dim=0)
+ for k, v in zip(extra_loss_keys, epoch_extra_losses.unbind()):
+ epoch_loss_dict[k] = v.item()
+ self.log_scalar_kwargs(
+ epoch=self.current_epoch, split='train',
+ **epoch_loss_dict,
+ )
+ logger.info(
+ f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' + \
+ ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v))
+ )
+
+ def train(self):
+
+ starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps
+ skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch)
+ logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========")
+
+ with tqdm(
+ range(0, self.N_max_global_steps),
+ initial=self.global_step,
+ disable=(not self.accelerator.is_main_process),
+ ) as pbar:
+
+ profiler = torch.profiler.profile(
+ activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+ schedule=torch.profiler.schedule(
+ wait=10, warmup=10, active=100,
+ ),
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(
+ self.cfg.logger.tracker_root,
+ self.cfg.experiment.parent, self.cfg.experiment.child,
+ )),
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ ) if self.cfg.logger.enable_profiler else DummyProfiler()
+
+ with profiler:
+ self.optimizer.zero_grad()
+ if self.has_disc:
+ self.optimizer_disc.zero_grad()
+ for iepoch in range(self.current_epoch, self.cfg.train.epochs):
+
+ loader = skipped_loader or self.train_loader
+ skipped_loader = None
+ self.train_epoch(pbar=pbar, loader=loader, profiler=profiler, iepoch=iepoch)
+ if self.accelerator.check_trigger():
+ break
+
+ logger.info(f"======== Training finished at global step {self.global_step} ========")
+
+ # final checkpoint and evaluation
+ self.save_checkpoint()
+ self.evaluate()
+
+ @torch.no_grad()
+ @torch.compiler.disable
+ def evaluate(self, epoch: int = None):
+ self.model.eval()
+
+ max_val_batches = self.cfg.val.debug_batches or len(self.val_loader)
+ running_losses = []
+ running_extra_losses = []
+ extra_loss_keys = []
+ sample_data, sample_outs = None, None
+
+ for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches):
+ data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype)
+ if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render:
+ data = self.prepare_cross_render_data(data)
+ data["real_num_view"] = 1
+ else:
+ data["real_num_view"] = data["render_image"].shape[1]
+
+ if len(running_losses) >= max_val_batches:
+ logger.info(f"======== Early stop validation at {len(running_losses)} batches ========")
+ break
+
+ outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)
+ extra_loss_dict = sorted(list(extra_loss_dict.keys()))
+ sample_data, sample_outs = data, outs
+
+ running_losses.append(torch.stack([
+ _loss if _loss is not None else torch.tensor(float('nan'), device=self.device)
+ for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask]
+ ]))
+ if len(extra_loss_keys) > 0:
+ running_extra_losses.append(torch.stack([
+ extra_loss_dict[k] if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device)
+ for k in extra_loss_keys
+ ]))
+
+ # log each step
+ conf_sigma_l1 = sample_outs.get('conf_sigma_l1', None)
+ conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None
+ conf_sigma_percl = sample_outs.get('conf_sigma_percl', None)
+ conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None
+ self.log_image_monitor_each_process(
+ step=self.global_step, split='val',
+ renders=sample_outs['comp_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl,
+ prefix=f"_{len(running_losses)}_rank{self.accelerator.process_index}"
+ )
+ if "comp_mask" in sample_outs.keys():
+ self.log_image_monitor_each_process(
+ step=self.global_step, split='val',
+ renders=sample_outs['comp_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ gts=sample_data['render_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ prefix=f"_mask_{len(running_losses)}_rank{self.accelerator.process_index}"
+ )
+
+ total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu()
+ total_loss, total_loss_pixel, total_loss_perceptual, total_loss_offset, total_loss_mask = total_losses.unbind()
+ total_loss_dict = {
+ 'loss': total_loss.item(),
+ 'loss_pixel': total_loss_pixel.item(),
+ 'loss_perceptual': total_loss_perceptual.item(),
+ 'loss_offset': total_loss_offset.item(),
+ 'loss_mask': total_loss_mask.item(),
+ }
+ if len(extra_loss_keys) > 0:
+ total_extra_losses = self.accelerator.gather(torch.stack(running_extra_losses)).mean(dim=0).cpu()
+ for k, v in zip(extra_loss_keys, total_extra_losses.unbind()):
+ total_loss_dict[k] = v.item()
+
+ if epoch is not None:
+ self.log_scalar_kwargs(
+ epoch=epoch, split='val',
+ **total_loss_dict,
+ )
+ logger.info(
+ f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' + \
+ ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
+ )
+ else:
+ self.log_scalar_kwargs(
+ step=self.global_step, split='val',
+ **total_loss_dict,
+ )
+ logger.info(
+ f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' + \
+ ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
+ )
+
+ def log_image_monitor_each_process(
+ self, epoch: int = None, step: int = None, split: str = None,
+ renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None,
+ conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None
+ ):
+ M = renders.shape[1]
+ if gts.shape[1] != M:
+ gts = repeat(gts, "b v c h w -> b (v r) c h w", r=2)
+ merged = torch.stack([renders, gts], dim=1)[0].view(-1, *renders.shape[2:])
+ renders, gts = renders.view(-1, *renders.shape[2:]), gts.view(-1, *gts.shape[2:])
+ renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M)
+ log_type, log_progress = self._get_str_progress(epoch, step)
+ split = f'/{split}' if split else ''
+ split = split + prefix if prefix is not None else split
+ log_img_dict = {
+ f'Images_split{split}/rendered': renders.unsqueeze(0),
+ f'Images_split{split}/gt': gts.unsqueeze(0),
+ f'Images_split{split}/merged': merged.unsqueeze(0),
+ }
+ if conf_sigma_l1 is not None:
+ EPS = 1e-7
+ vis_conf_l1 = 1/(1+conf_sigma_l1.detach()+EPS).cpu()
+ vis_conf_percl = 1/(1+conf_sigma_percl.detach()+EPS).cpu()
+ vis_conf_l1, vis_conf_percl = rearrange(vis_conf_l1, "b v (r c) h w -> (b v r) c h w", r=2), rearrange(vis_conf_percl, "b v (r c) h w -> (b v r) c h w", r=2)
+ vis_conf_l1, vis_conf_percl = repeat(vis_conf_l1, "b c1 h w-> b (c1 c2) h w", c2=3), repeat(vis_conf_percl, "b c1 h w -> b (c1 c2) h w", c2=3)
+ vis_conf_l1, vis_conf_percl = make_grid(vis_conf_l1, nrow=M), make_grid(vis_conf_percl, nrow=M)
+ log_img_dict[f'Images_split{split}/conf_l1'] = vis_conf_l1.unsqueeze(0)
+ log_img_dict[f'Images_split{split}/conf_percl'] = vis_conf_percl.unsqueeze(0)
+
+ self.log_images_each_process(log_img_dict, log_progress, {"imwrite_image": False})
+
+
+ @Trainer.control('on_main_process')
+ def log_image_monitor(
+ self, epoch: int = None, step: int = None, split: str = None,
+ renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None,
+ conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None
+ ):
+ self.log_image_monitor_each_process(epoch, step, split, renders, gts, prefix, conf_sigma_l1, conf_sigma_percl)
diff --git a/lam/utils/__init__.py b/lam/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/lam/utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/lam/utils/compile.py b/lam/utils/compile.py
new file mode 100644
index 0000000000000000000000000000000000000000..08972a23daf1c046c327ce93fc667b706a3ec65b
--- /dev/null
+++ b/lam/utils/compile.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def configure_dynamo(config: dict):
+ try:
+ import torch._dynamo
+ logger.debug(f'Configuring torch._dynamo.config with {config}')
+ for k, v in config.items():
+ if v is None:
+ logger.debug(f'Skipping torch._dynamo.config.{k} with None')
+ continue
+ if hasattr(torch._dynamo.config, k):
+ logger.warning(f'Overriding torch._dynamo.config.{k} from {getattr(torch._dynamo.config, k)} to {v}')
+ setattr(torch._dynamo.config, k, v)
+ except ImportError:
+ logger.debug('torch._dynamo not found, skipping')
+ pass
diff --git a/lam/utils/ffmpeg_utils.py b/lam/utils/ffmpeg_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6c4ec6575789be98d261ebd35f3683ef0d8881e
--- /dev/null
+++ b/lam/utils/ffmpeg_utils.py
@@ -0,0 +1,64 @@
+import os
+import pdb
+import torch
+import numpy as np
+import imageio
+import cv2
+import imageio.v3 as iio
+
+VIDEO_TYPE_LIST = {'.avi','.mp4','.gif','.AVI','.MP4','.GIF'}
+
+def encodeffmpeg(inputs, frame_rate, output, format="png"):
+ """output: need video_name"""
+ assert (
+ os.path.splitext(output)[-1] in VIDEO_TYPE_LIST
+ ), "output is the format of video, e.g., mp4"
+ assert os.path.isdir(inputs), "input dir is NOT file format"
+
+ inputs = inputs[:-1] if inputs[-1] == "/" else inputs
+
+ output = os.path.abspath(output)
+
+ cmd = (
+ f"ffmpeg -r {frame_rate} -pattern_type glob -i '{inputs}/*.{format}' "
+ + f'-vcodec libx264 -crf 10 -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" '
+ + f"-pix_fmt yuv420p {output} > /dev/null 2>&1"
+ )
+
+ print(cmd)
+
+ output_dir = os.path.dirname(output)
+ if os.path.exists(output):
+ os.remove(output)
+ os.makedirs(output_dir, exist_ok=True)
+
+ print("encoding imgs to video.....")
+ os.system(cmd)
+ print("video done!")
+
+def images_to_video(images, output_path, fps, gradio_codec: bool, verbose=False, bitrate="2M"):
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ frames = []
+ for i in range(images.shape[0]):
+ if isinstance(images, torch.Tensor):
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+ assert frame.min() >= 0 and frame.max() <= 255, \
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
+ else:
+ frame = images[i]
+ width, height = frame.shape[1], frame.shape[0]
+ # reshape to limit the export time
+ # if width > 1200 or height > 1200 or images.shape[0] > 200:
+ # frames.append(cv2.resize(frame, (width // 2, height // 2)))
+ # else:
+ frames.append(frame)
+ # limit the frames directly @NOTE huggingface only!
+ frames = frames[:200]
+
+ frames = np.stack(frames)
+
+ print("start saving {} using imageio.v3 .".format(output_path))
+ iio.imwrite(output_path,frames,fps=fps,codec="libx264",pixelformat="yuv420p",bitrate=bitrate, macro_block_size=32)
+ print("saved {} using imageio.v3 .".format(output_path))
\ No newline at end of file
diff --git a/lam/utils/gen_id_json.py b/lam/utils/gen_id_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..270240a3b15e15ffc2b2d684a40aacdf630dcdc7
--- /dev/null
+++ b/lam/utils/gen_id_json.py
@@ -0,0 +1,18 @@
+import json
+import glob
+import sys
+import os
+
+data_root = sys.argv[1]
+save_path = sys.argv[2]
+
+all_hid_list = []
+for hid in os.listdir(data_root):
+ if hid.startswith("p"):
+ hid = os.path.join(data_root, hid)
+ all_hid_list.append(hid.replace(data_root + "/", ""))
+
+print(f"len:{len(all_hid_list)}")
+print(all_hid_list[:3])
+with open(save_path, 'w') as fp:
+ json.dump(all_hid_list, fp, indent=4)
\ No newline at end of file
diff --git a/lam/utils/gen_json.py b/lam/utils/gen_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..768ea84d0791e5e960c77aa7f9b55e9096fbe3d4
--- /dev/null
+++ b/lam/utils/gen_json.py
@@ -0,0 +1,23 @@
+import json
+import glob
+import sys
+import os
+
+data_root = sys.argv[1]
+save_path = sys.argv[2]
+
+all_img_list = []
+for hid in os.listdir(data_root):
+ all_view_imgs_dir = os.path.join(data_root, hid, "kinect_color")
+ if not os.path.exists(all_view_imgs_dir):
+ continue
+
+ for view_id in os.listdir(all_view_imgs_dir):
+ imgs_dir = os.path.join(all_view_imgs_dir, view_id)
+ for img_path in glob.glob(os.path.join(imgs_dir, "*.png")):
+ all_img_list.append(img_path.replace(data_root + "/", ""))
+
+print(f"len:{len(all_img_list)}")
+print(all_img_list[:3])
+with open(save_path, 'w') as fp:
+ json.dump(all_img_list, fp, indent=4)
\ No newline at end of file
diff --git a/lam/utils/hf_hub.py b/lam/utils/hf_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9ba0df56983a407d20c2c656a82c1ad15487ca5
--- /dev/null
+++ b/lam/utils/hf_hub.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+
+
+def wrap_model_hub(model_cls: nn.Module):
+ class HfModel(model_cls, PyTorchModelHubMixin):
+ def __init__(self, config: dict):
+ super().__init__(**config)
+ self.config = config
+ return HfModel
diff --git a/lam/utils/logging.py b/lam/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e2ecd77ff0d1dc9b7fa5cb4efc6edcda8a18d0d
--- /dev/null
+++ b/lam/utils/logging.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import logging
+from tqdm.auto import tqdm
+
+
+class TqdmStreamHandler(logging.StreamHandler):
+ def emit(self, record):
+ tqdm.write(self.format(record))
+
+
+def configure_logger(stream_level, log_level, file_path = None):
+ _stream_level = stream_level.upper()
+ _log_level = log_level.upper()
+ _project_level = _log_level
+
+ _formatter = logging.Formatter("[%(asctime)s] %(name)s: [%(levelname)s] %(message)s")
+
+ _stream_handler = TqdmStreamHandler()
+ _stream_handler.setLevel(_stream_level)
+ _stream_handler.setFormatter(_formatter)
+
+ if file_path is not None:
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ _file_handler = logging.FileHandler(file_path)
+ _file_handler.setLevel(_log_level)
+ _file_handler.setFormatter(_formatter)
+
+ _project_logger = logging.getLogger(__name__.split('.')[0])
+ _project_logger.setLevel(_project_level)
+ _project_logger.addHandler(_stream_handler)
+ if file_path is not None:
+ _project_logger.addHandler(_file_handler)
diff --git a/lam/utils/preprocess.py b/lam/utils/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..4724a4c5ed6cba9e16dac265bbbaf105a0b57dd6
--- /dev/null
+++ b/lam/utils/preprocess.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import rembg
+import cv2
+
+
+class Preprocessor:
+
+ """
+ Preprocessing under cv2 conventions.
+ """
+
+ def __init__(self):
+ self.rembg_session = rembg.new_session(
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+ )
+
+ def preprocess(self, image_path: str, save_path: str, rmbg: bool = True, recenter: bool = True, size: int = 512, border_ratio: float = 0.2):
+ image = self.step_load_to_size(image_path=image_path, size=size*2)
+ if rmbg:
+ image = self.step_rembg(image_in=image)
+ else:
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
+ if recenter:
+ image = self.step_recenter(image_in=image, border_ratio=border_ratio, square_size=size)
+ else:
+ image = cv2.resize(
+ src=image,
+ dsize=(size, size),
+ interpolation=cv2.INTER_AREA,
+ )
+ return cv2.imwrite(save_path, image)
+
+ def step_rembg(self, image_in: np.ndarray) -> np.ndarray:
+ image_out = rembg.remove(
+ data=image_in,
+ session=self.rembg_session,
+ )
+ return image_out
+
+ def step_recenter(self, image_in: np.ndarray, border_ratio: float, square_size: int) -> np.ndarray:
+ assert image_in.shape[-1] == 4, "Image to recenter must be RGBA"
+ mask = image_in[..., -1] > 0
+ ijs = np.nonzero(mask)
+ # find bbox
+ i_min, i_max = ijs[0].min(), ijs[0].max()
+ j_min, j_max = ijs[1].min(), ijs[1].max()
+ bbox_height, bbox_width = i_max - i_min, j_max - j_min
+ # recenter and resize
+ desired_size = int(square_size * (1 - border_ratio))
+ scale = desired_size / max(bbox_height, bbox_width)
+ desired_height, desired_width = int(bbox_height * scale), int(bbox_width * scale)
+ desired_i_min, desired_j_min = (square_size - desired_height) // 2, (square_size - desired_width) // 2
+ desired_i_max, desired_j_max = desired_i_min + desired_height, desired_j_min + desired_width
+ # create new image
+ image_out = np.zeros((square_size, square_size, 4), dtype=np.uint8)
+ image_out[desired_i_min:desired_i_max, desired_j_min:desired_j_max] = cv2.resize(
+ src=image_in[i_min:i_max, j_min:j_max],
+ dsize=(desired_width, desired_height),
+ interpolation=cv2.INTER_AREA,
+ )
+ return image_out
+
+ def step_load_to_size(self, image_path: str, size: int) -> np.ndarray:
+ image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
+ height, width = image.shape[:2]
+ scale = size / max(height, width)
+ height, width = int(height * scale), int(width * scale)
+ image_out = cv2.resize(
+ src=image,
+ dsize=(width, height),
+ interpolation=cv2.INTER_AREA,
+ )
+ return image_out
diff --git a/lam/utils/profiler.py b/lam/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ba79973308b627d5b20bdd7bb09eac138c93ad
--- /dev/null
+++ b/lam/utils/profiler.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from torch.profiler import profile
+
+
+class DummyProfiler(profile):
+ def __init__(self):
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ pass
+
+ def step(self):
+ pass
diff --git a/lam/utils/proxy.py b/lam/utils/proxy.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddfbc642fbe489a3ede3cc208e70867a81c8f912
--- /dev/null
+++ b/lam/utils/proxy.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+NO_PROXY = "lam_NO_DATA_PROXY" in os.environ
+
+def no_proxy(func):
+ """Decorator to disable proxy but then restore after the function call."""
+ def wrapper(*args, **kwargs):
+ # http_proxy, https_proxy, HTTP_PROXY, HTTPS_PROXY, all_proxy
+ http_proxy = os.environ.get('http_proxy')
+ https_proxy = os.environ.get('https_proxy')
+ HTTP_PROXY = os.environ.get('HTTP_PROXY')
+ HTTPS_PROXY = os.environ.get('HTTPS_PROXY')
+ all_proxy = os.environ.get('all_proxy')
+ os.environ['http_proxy'] = ''
+ os.environ['https_proxy'] = ''
+ os.environ['HTTP_PROXY'] = ''
+ os.environ['HTTPS_PROXY'] = ''
+ os.environ['all_proxy'] = ''
+ try:
+ return func(*args, **kwargs)
+ finally:
+ os.environ['http_proxy'] = http_proxy
+ os.environ['https_proxy'] = https_proxy
+ os.environ['HTTP_PROXY'] = HTTP_PROXY
+ os.environ['HTTPS_PROXY'] = HTTPS_PROXY
+ os.environ['all_proxy'] = all_proxy
+ if NO_PROXY:
+ return wrapper
+ else:
+ return func
diff --git a/lam/utils/registry.py b/lam/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..421a735f82899c50884cd5b5a27e71757b2eb813
--- /dev/null
+++ b/lam/utils/registry.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class Registry:
+ """Registry class"""
+
+ def __init__(self):
+ self._registry = {}
+
+ def register(self, name):
+ """Register a module"""
+ def decorator(cls):
+ assert name not in self._registry, 'Module {} already registered'.format(name)
+ self._registry[name] = cls
+ return cls
+ return decorator
+
+ def __getitem__(self, name):
+ """Get a module"""
+ return self._registry[name]
+
+ def __contains__(self, name):
+ return name in self._registry
diff --git a/lam/utils/scheduler.py b/lam/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc151d816e2787f37f9bea02b0945e06a933c01
--- /dev/null
+++ b/lam/utils/scheduler.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from torch.optim.lr_scheduler import LRScheduler
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class CosineWarmupScheduler(LRScheduler):
+ def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1):
+ self.warmup_iters = warmup_iters
+ self.max_iters = max_iters
+ self.initial_lr = initial_lr
+ super().__init__(optimizer, last_iter)
+
+ def get_lr(self):
+ logger.debug(f"step count: {self._step_count} | warmup iters: {self.warmup_iters} | max iters: {self.max_iters}")
+ if self._step_count <= self.warmup_iters:
+ return [
+ self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters
+ for base_lr in self.base_lrs]
+ else:
+ cos_iter = self._step_count - self.warmup_iters
+ cos_max_iter = self.max_iters - self.warmup_iters
+ cos_theta = cos_iter / cos_max_iter * math.pi
+ cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs]
+ return cos_lr
diff --git a/lam/utils/video.py b/lam/utils/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbaaee42aa6102e471dc9479aaab9f22ca7cb5b1
--- /dev/null
+++ b/lam/utils/video.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import numpy as np
+import torch
+
+def images_to_video(images, output_path, fps, gradio_codec: bool, verbose=False):
+ import imageio
+ # images: torch.tensor (T, C, H, W), 0-1 or numpy: (T, H, W, 3) 0-255
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ frames = []
+ for i in range(images.shape[0]):
+ if isinstance(images, torch.Tensor):
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+ assert frame.min() >= 0 and frame.max() <= 255, \
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
+ else:
+ frame = images[i]
+ frames.append(frame)
+ frames = np.stack(frames)
+ if gradio_codec:
+ imageio.mimwrite(output_path, frames, fps=fps, quality=10)
+ else:
+ # imageio.mimwrite(output_path, frames, fps=fps, codec='mpeg4', quality=10)
+ imageio.mimwrite(output_path, frames, fps=fps, quality=10)
+
+ if verbose:
+ print(f"Using gradio codec option {gradio_codec}")
+ print(f"Saved video to {output_path}")
+
+
+def save_images2video(img_lst, v_pth, fps):
+ import moviepy.editor as mpy
+ # Convert the list of NumPy arrays to a list of ImageClip objects
+ clips = [mpy.ImageClip(img).set_duration(0.1) for img in img_lst] # 0.1 seconds per frame
+
+ # Concatenate the ImageClips into a single VideoClip
+ video = mpy.concatenate_videoclips(clips, method="compose")
+
+ # Write the VideoClip to a file
+ video.write_videofile(v_pth, fps=fps) # setting fps to 10 as example
+ print("save video to:", v_pth)
+
+
+if __name__ == "__main__":
+ from glob import glob
+ clip_name = "clip1"
+ ptn = f"./assets/sample_motion/export/{clip_name}/images/*.png"
+ images_pths = glob(ptn)
+ import cv2
+ import numpy as np
+ images = [cv2.imread(pth) for pth in images_pths]
+ save_images2video(images, "./assets/sample_mption/export/{clip_name}/video.mp4", 25, True)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..82076889ac7f57378cdd6d77314225e742dba20a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,57 @@
+einops
+roma
+accelerate
+smplx
+iopath
+wheel
+# gradio
+face-detection-tflite
+moviepy==1.0.3
+decord==0.6.0
+diffusers
+dna==0.0.1
+gfpgan==1.3.8
+gsplat==1.4.0
+# huggingface_hub==0.27.0
+huggingface_hub==0.23.2
+imageio==2.19.3
+jaxtyping==0.2.38
+kiui==0.2.14
+kornia==0.7.2
+loguru==0.7.3
+lpips==0.1.4
+matplotlib==3.5.3
+megfile==4.1.0.post2
+numpy==1.23.0
+omegaconf==2.3.0
+open3d==0.19.0
+opencv_python
+opencv_python_headless
+Pillow==11.1.0
+plyfile
+pygltflib==1.16.2
+pyrender==0.1.45
+PyYAML==6.0.1
+rembg==2.0.63
+Requests==2.32.3
+scipy
+setuptools==74.0.0
+taming_transformers_rom1504==0.0.6
+timm==1.0.15
+pymcubes==0.1.6
+
+https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=28bfba084dca52a06c465d7ad0f3cc372c35fc503f3eab881cc17a5fd82914e7
+https://download.pytorch.org/whl/cu121/torchvision-0.19.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=5ee103c7eb47f8b08837e0e48b178f7ecc91d769d2b61240b90cb5aa2d06ce77
+
+tqdm==4.66.4
+transformers==4.41.2
+trimesh==4.4.9
+typeguard
+xatlas==0.0.9
+imageio-ffmpeg
+rembg[cpu]
+tyro==0.9.17
+pandas==2.2.3
+chumpy==0.70
+nvdiffrast@git+https://github.com/ShenhanQian/nvdiffrast@backface-culling
+pydantic==2.8.0
\ No newline at end of file
diff --git a/requirements_lhm.txt b/requirements_lhm.txt
new file mode 100644
index 0000000000000000000000000000000000000000..207884130435e82ff63dea63653ef0418ab15f9b
--- /dev/null
+++ b/requirements_lhm.txt
@@ -0,0 +1,58 @@
+einops
+roma
+accelerate
+smplx
+iopath
+# gradio
+wheel
+# chumpy==0.66
+decord==0.6.0
+diffusers
+dna==0.0.1
+gfpgan==1.3.8
+gsplat==1.4.0
+huggingface_hub==0.23.2
+imageio==2.19.3
+jaxtyping==0.2.38
+kiui==0.2.14
+kornia==0.7.2
+loguru==0.7.3
+lpips==0.1.4
+matplotlib==3.5.3
+megfile==4.1.0.post2
+numpy==1.23.0
+omegaconf==2.3.0
+open3d==0.19.0
+opencv_python
+opencv_python_headless
+Pillow==11.1.0
+plyfile
+pygltflib==1.16.2
+pyrender==0.1.45
+PyYAML==6.0.1
+rembg==2.0.63
+Requests==2.32.3
+scipy
+setuptools==74.0.0
+taming_transformers_rom1504==0.0.6
+timm==1.0.15
+
+# https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=92af92c569de5da937dd1afb45ecfdd598ec1254cf2e49e3d698cb24d71aae14
+# https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=304937b82c933d5155bd04d771f4b187273f67a76050bb4276b521f7e9b4c4e7
+# https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=e213ff8123e20602bd486739ffee4013338b02f9d2e0e4635a2912750854fdbe
+
+https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=28bfba084dca52a06c465d7ad0f3cc372c35fc503f3eab881cc17a5fd82914e7
+https://download.pytorch.org/whl/cu121/torchvision-0.19.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=5ee103c7eb47f8b08837e0e48b178f7ecc91d769d2b61240b90cb5aa2d06ce77
+
+--no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html
+
+tqdm==4.66.4
+transformers==4.41.2
+trimesh==4.4.9
+typeguard==2.13.3
+xatlas==0.0.9
+imageio-ffmpeg
+rembg[cpu]
+
+./wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl
+./wheels/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl
\ No newline at end of file
diff --git a/requirements_real.txt b/requirements_real.txt
new file mode 100644
index 0000000000000000000000000000000000000000..be1b83ed0f99b6d160517c53cdf1ff7ed94f30ff
--- /dev/null
+++ b/requirements_real.txt
@@ -0,0 +1,48 @@
+einops
+roma
+accelerate
+smplx
+iopath
+# gradio
+chumpy
+decord==0.6.0
+diffusers
+dna==0.0.1
+gfpgan==1.3.8
+gsplat==1.4.0
+huggingface_hub==0.23.2
+imageio==2.19.3
+jaxtyping==0.2.38
+kiui==0.2.14
+kornia==0.7.2
+loguru==0.7.3
+lpips==0.1.4
+matplotlib==3.5.3
+megfile==4.1.0.post2
+numpy==1.23.0
+omegaconf==2.3.0
+open3d==0.19.0
+opencv_python
+opencv_python_headless
+Pillow==11.1.0
+plyfile
+pygltflib==1.16.2
+pyrender==0.1.45
+PyYAML==6.0.1
+rembg==2.0.63
+Requests==2.32.3
+scipy
+setuptools==74.0.0
+taming_transformers_rom1504==0.0.6
+timm==1.0.15
+
+https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=92af92c569de5da937dd1afb45ecfdd598ec1254cf2e49e3d698cb24d71aae14
+https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=304937b82c933d5155bd04d771f4b187273f67a76050bb4276b521f7e9b4c4e7
+https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=e213ff8123e20602bd486739ffee4013338b02f9d2e0e4635a2912750854fdbe
+
+tqdm==4.66.4
+transformers==4.41.2
+trimesh==4.4.9
+typeguard==2.13.3
+xatlas==0.0.9
+imageio-ffmpeg
\ No newline at end of file
diff --git a/scripts/convert_hf.py b/scripts/convert_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..301ac9c81899b2133b0ee84fcd5c973a572bbbb1
--- /dev/null
+++ b/scripts/convert_hf.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import pdb
+import sys
+import traceback
+from tempfile import TemporaryDirectory
+
+import safetensors
+import torch.nn as nn
+from accelerate import Accelerator
+from megfile import (
+ smart_copy,
+ smart_exists,
+ smart_listdir,
+ smart_makedirs,
+ smart_path_join,
+)
+from omegaconf import OmegaConf
+
+sys.path.append(".")
+
+from LHM.models import model_dict
+from LHM.utils.hf_hub import wrap_model_hub
+from LHM.utils.proxy import no_proxy
+
+
+@no_proxy
+def auto_load_model(cfg, model: nn.Module) -> int:
+
+ ckpt_root = smart_path_join(
+ cfg.saver.checkpoint_root,
+ cfg.experiment.parent,
+ cfg.experiment.child,
+ )
+ if not smart_exists(ckpt_root):
+ raise FileNotFoundError(f"Checkpoint root not found: {ckpt_root}")
+ ckpt_dirs = smart_listdir(ckpt_root)
+ if len(ckpt_dirs) == 0:
+ raise FileNotFoundError(f"No checkpoint found in {ckpt_root}")
+ ckpt_dirs.sort()
+
+ load_step = (
+ f"{cfg.convert.global_step}"
+ if cfg.convert.global_step is not None
+ else ckpt_dirs[-1]
+ )
+ load_model_path = smart_path_join(ckpt_root, load_step, "model.safetensors")
+
+ if load_model_path.startswith("s3"):
+ tmpdir = TemporaryDirectory()
+ tmp_model_path = smart_path_join(tmpdir.name, f"tmp.safetensors")
+ smart_copy(load_model_path, tmp_model_path)
+ load_model_path = tmp_model_path
+
+ print(f"Loading from {load_model_path}")
+ try:
+ safetensors.torch.load_model(model, load_model_path, strict=True)
+ except:
+ traceback.print_exc()
+ safetensors.torch.load_model(model, load_model_path, strict=False)
+
+ return int(load_step)
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="./assets/config.yaml")
+ args, unknown = parser.parse_known_args()
+ cfg = OmegaConf.load(args.config)
+ cli_cfg = OmegaConf.from_cli(unknown)
+ cfg = OmegaConf.merge(cfg, cli_cfg)
+
+ """
+ [cfg.convert]
+ global_step: int
+ save_dir: str
+ """
+
+ accelerator = Accelerator()
+
+ # hf_model_cls = wrap_model_hub(model_dict[cfg.experiment.type])
+ hf_model_cls = wrap_model_hub(model_dict["human_lrm_sapdino_bh_sd3_5"])
+
+ hf_model = hf_model_cls(OmegaConf.to_container(cfg.model))
+ loaded_step = auto_load_model(cfg, hf_model)
+ dump_path = smart_path_join(
+ f"./exps/releases",
+ cfg.experiment.parent,
+ cfg.experiment.child,
+ f"step_{loaded_step:06d}",
+ )
+ print(f"Saving locally to {dump_path}")
+ smart_makedirs(dump_path, exist_ok=True)
+ hf_model.save_pretrained(
+ save_directory=dump_path,
+ config=hf_model.config,
+ )
diff --git a/scripts/exp/run_4gpu.sh b/scripts/exp/run_4gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ea90e0647db9b1180dc27c5959979095c91f3ab4
--- /dev/null
+++ b/scripts/exp/run_4gpu.sh
@@ -0,0 +1,16 @@
+ ACC_CONFIG="./configs/accelerate-train-4gpu.yaml"
+ TRAIN_CONFIG="./configs/train-sample-human.yaml"
+
+ if [ -n "$1" ]; then
+ TRAIN_CONFIG=$1
+ else
+ TRAIN_CONFIG="./configs/train-sample-human.yaml"
+ fi
+
+ if [ -n "$2" ]; then
+ MAIN_PORT=$2
+ else
+ MAIN_PORT=12345
+ fi
+
+ accelerate launch --config_file $ACC_CONFIG --main_process_port=$MAIN_PORT -m openlrm.launch train.human_lrm --config $TRAIN_CONFIG
\ No newline at end of file
diff --git a/scripts/exp/run_8gpu.sh b/scripts/exp/run_8gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f6f65a66db97c04adae59f8c0ca3018ba8bd606f
--- /dev/null
+++ b/scripts/exp/run_8gpu.sh
@@ -0,0 +1,16 @@
+ ACC_CONFIG="./configs/accelerate-train.yaml"
+ TRAIN_CONFIG="./configs/train-sample-human.yaml"
+
+ if [ -n "$1" ]; then
+ TRAIN_CONFIG=$1
+ else
+ TRAIN_CONFIG="./configs/train-sample-human.yaml"
+ fi
+
+ if [ -n "$2" ]; then
+ MAIN_PORT=$2
+ else
+ MAIN_PORT=12345
+ fi
+
+ accelerate launch --config_file $ACC_CONFIG --main_process_port=$MAIN_PORT -m openlrm.launch train.human_lrm --config $TRAIN_CONFIG
\ No newline at end of file
diff --git a/scripts/exp/run_debug.sh b/scripts/exp/run_debug.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6aa6233149e06b19da6bcd288d04ef91e3d961c0
--- /dev/null
+++ b/scripts/exp/run_debug.sh
@@ -0,0 +1,15 @@
+ ACC_CONFIG="./configs/accelerate-train-1gpu.yaml"
+
+ if [ -n "$1" ]; then
+ TRAIN_CONFIG=$1
+ else
+ TRAIN_CONFIG="./configs/train-sample-human.yaml"
+ fi
+
+ if [ -n "$2" ]; then
+ MAIN_PORT=$2
+ else
+ MAIN_PORT=12345
+ fi
+
+ accelerate launch --config_file $ACC_CONFIG --main_process_port=$MAIN_PORT -m openlrm.launch train.human_lrm --config $TRAIN_CONFIG
\ No newline at end of file
diff --git a/scripts/upload_hub.py b/scripts/upload_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..52fba14d95d367a776c45a63fbfef8054b2e1406
--- /dev/null
+++ b/scripts/upload_hub.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+
+sys.path.append(".")
+
+import argparse
+
+from accelerate import Accelerator
+
+from LHM.models import model_dict
+from LHM.utils.hf_hub import wrap_model_hub
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, required=True)
+ parser.add_argument("--local_ckpt", type=str, required=True)
+ parser.add_argument("--repo_id", type=str, required=True)
+ args, unknown = parser.parse_known_args()
+
+ accelerator = Accelerator()
+
+ hf_model_cls = wrap_model_hub(model_dict[args.model_type])
+ hf_model = hf_model_cls.from_pretrained(args.local_ckpt)
+ hf_model.push_to_hub(
+ repo_id=args.repo_id,
+ config=hf_model.config,
+ private=True,
+ )
diff --git a/vhap/combine_nerf_datasets.py b/vhap/combine_nerf_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..5721cd26831e32a0a9895dc1d8c9f31402207ed7
--- /dev/null
+++ b/vhap/combine_nerf_datasets.py
@@ -0,0 +1,174 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+from typing import Optional, Literal, List
+from copy import deepcopy
+import json
+import tyro
+from pathlib import Path
+import shutil
+import random
+
+
+class NeRFDatasetAssembler:
+ def __init__(self, src_folders: List[Path], tgt_folder: Path, division_mode: Literal['random_single', 'random_group', 'last']='random_group'):
+ self.src_folders = src_folders
+ self.tgt_folder = tgt_folder
+ self.num_timestep = 0
+
+ # use the subject name as the random seed to sample the test sequence
+ subjects = [sf.name.split('_')[0] for sf in src_folders]
+ for s in subjects:
+ assert s == subjects[0], f"Cannot combine datasets from different subjects: {subjects}"
+ subject = subjects[0]
+ random.seed(subject)
+
+ if division_mode == 'random_single':
+ self.src_folders_test = [self.src_folders.pop(int(random.uniform(0, 1) * len(src_folders)))]
+ elif division_mode == 'random_group':
+ # sample one sequence as the test sequence every `group_size` sequences
+ self.src_folders_test = []
+ num_all = len(self.src_folders)
+ group_size = 10
+ num_test = max(1, num_all // group_size)
+ indices_test = []
+ for gi in range(num_test):
+ idx = min(num_all - 1, random.randint(0, group_size - 1) + gi * group_size)
+ indices_test.append(idx)
+
+ for idx in indices_test:
+ self.src_folders_test.append(self.src_folders.pop(idx))
+ elif division_mode == 'last':
+ self.src_folders_test = [self.src_folders.pop(-1)]
+ else:
+ raise ValueError(f"Unknown division mode: {division_mode}")
+
+ self.src_folders_train = self.src_folders
+
+ def write(self):
+ self.combine_dbs(self.src_folders_train, division='train')
+ self.combine_dbs(self.src_folders_test, division='test')
+
+ def combine_dbs(self, src_folders, division: Optional[Literal['train', 'test']] = None):
+ db = None
+ for i, src_folder in enumerate(src_folders):
+ dbi_path = src_folder / "transforms.json"
+ assert dbi_path.exists(), f"Could not find {dbi_path}"
+ # print(f"Loading database: {dbi_path}")
+ dbi = json.load(open(dbi_path, "r"))
+
+ dbi['timestep_indices'] = [t + self.num_timestep for t in dbi['timestep_indices']]
+ self.num_timestep += len(dbi['timestep_indices'])
+ for frame in dbi['frames']:
+ # drop keys that are irrelevant for a combined dataset
+ frame.pop('timestep_index_original')
+ frame.pop('timestep_id')
+
+ # accumulate timestep indices
+ frame['timestep_index'] = dbi['timestep_indices'][frame['timestep_index']]
+
+ # complement the parent folder
+ frame['file_path'] = str(Path('..') / Path(src_folder.name) / frame['file_path'])
+ frame['flame_param_path'] = str(Path('..') / Path(src_folder.name) / frame['flame_param_path'])
+ frame['fg_mask_path'] = str(Path('..') / Path(src_folder.name) / frame['fg_mask_path'])
+
+ if db is None:
+ db = dbi
+ else:
+ db['frames'] += dbi['frames']
+ db['timestep_indices'] += dbi['timestep_indices']
+
+ if not self.tgt_folder.exists():
+ self.tgt_folder.mkdir(parents=True)
+
+ if division == 'train':
+ # copy the canonical flame param
+ cano_flame_param_path = src_folders[0] / "canonical_flame_param.npz"
+ tgt_flame_param_path = self.tgt_folder / f"canonical_flame_param.npz"
+ print(f"Copying canonical flame param: {tgt_flame_param_path}")
+ shutil.copy(cano_flame_param_path, tgt_flame_param_path)
+
+ # leave one camera for validation
+ db_train = {k: v for k, v in db.items() if k not in ['frames', 'camera_indices']}
+ db_train['frames'] = []
+ db_val = deepcopy(db_train)
+
+ if len(db['camera_indices']) > 1:
+ # when having multiple cameras, leave one camera for validation (novel-view sythesis)
+ if 8 in db['camera_indices']:
+ # use camera 8 for validation (front-view of the NeRSemble dataset)
+ db_train['camera_indices'] = [i for i in db['camera_indices'] if i != 8]
+ db_val['camera_indices'] = [8]
+ else:
+ # use the last camera for validation
+ db_train['camera_indices'] = db['camera_indices'][:-1]
+ db_val['camera_indices'] = [db['camera_indices'][-1]]
+ else:
+ # when only having one camera, we create an empty validation set
+ db_train['camera_indices'] = db['camera_indices']
+ db_val['camera_indices'] = []
+
+ for frame in db['frames']:
+ if frame['camera_index'] in db_train['camera_indices']:
+ db_train['frames'].append(frame)
+ elif frame['camera_index'] in db_val['camera_indices']:
+ db_val['frames'].append(frame)
+ else:
+ raise ValueError(f"Unknown camera index: {frame['camera_index']}")
+
+ write_json(db_train, self.tgt_folder, 'train')
+ write_json(db_val, self.tgt_folder, 'val')
+
+ with open(self.tgt_folder / 'sequences_trainval.txt', 'w') as f:
+ for folder in src_folders:
+ f.write(folder.name + '\n')
+ else:
+ db['timestep_indices'] = sorted(db['timestep_indices'])
+ write_json(db, self.tgt_folder, division)
+
+ with open(self.tgt_folder / f'sequences_{division}.txt', 'w') as f:
+ for folder in src_folders:
+ f.write(folder.name + '\n')
+
+
+def write_json(db, tgt_folder, division=None):
+ fname = "transforms.json" if division is None else f"transforms_{division}.json"
+ json_path = tgt_folder / fname
+ print(f"Writing database: {json_path}")
+ with open(json_path, "w") as f:
+ json.dump(db, f, indent=4)
+
+def main(
+ src_folders: List[Path],
+ tgt_folder: Path,
+ division_mode: Literal['random_single', 'random_group', 'last']='random_group',
+ ):
+ incomplete = False
+ print("==== Begin assembling datasets ====")
+ print(f"Division mode: {division_mode}")
+ for src_folder in src_folders:
+ try:
+ assert src_folder.exists(), f"Error: could not find {src_folder}"
+ assert src_folder.parent == tgt_folder.parent, "All source folders must be in the same parent folder as the target folder"
+ # print(src_folder)
+ except AssertionError as e:
+ print(e)
+ incomplete = True
+
+ if incomplete:
+ return
+
+ nerf_dataset_assembler = NeRFDatasetAssembler(src_folders, tgt_folder, division_mode)
+ nerf_dataset_assembler.write()
+
+ print("Done!")
+
+
+if __name__ == "__main__":
+ tyro.cli(main)
diff --git a/vhap/config/base.py b/vhap/config/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..1824d571e990d276f0908be8e8cd42f9afe9b8cd
--- /dev/null
+++ b/vhap/config/base.py
@@ -0,0 +1,353 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Literal, Tuple
+import tyro
+import importlib
+from vhap.util.log import get_logger
+logger = get_logger(__name__)
+
+
+def import_module(module_name: str):
+ module_name, class_name = module_name.rsplit(".", 1)
+ module = getattr(importlib.import_module(module_name), class_name)
+ return module
+
+
+class Config:
+ def __getitem__(self, __name: str):
+ if hasattr(self, __name):
+ return getattr(self, __name)
+ else:
+ raise AttributeError(f"{self.__class__.__name__} has no attribute '{__name}'")
+
+
+@dataclass()
+class DataConfig(Config):
+ root_folder: Path = ''
+ """The root folder for the dataset."""
+ sequence: str = ''
+ """The sequence name"""
+ _target: str = "vhap.data.video_dataset.VideoDataset"
+ """The target dataset class"""
+ division: Optional[str] = None
+ subset: Optional[str] = None
+ calibrated: bool = False
+ """Whether the cameras parameters are available"""
+ align_cameras_to_axes: bool = True
+ """Adjust how cameras distribute in the space with a global rotation"""
+ camera_convention_conversion: str = 'opencv->opengl'
+ target_extrinsic_type: Literal['w2c', 'c2w'] = 'w2c'
+ n_downsample_rgb: Optional[int] = None
+ """Load from downsampled RGB images to save data IO time"""
+ scale_factor: float = 1.0
+ """Further apply a scaling transformation after the downsampling of RGB"""
+ background_color: Optional[Literal['white', 'black']] = 'white'
+ use_alpha_map: bool = False
+ use_landmark: bool = True
+ landmark_source: Optional[Literal['face-alignment', 'star']] = "star"
+
+
+@dataclass()
+class ModelConfig(Config):
+ n_shape: int = 300
+ n_expr: int = 100
+ n_tex: int = 100
+
+ use_static_offset: bool = False
+ """Optimize static offsets on top of FLAME vertices in the canonical space"""
+ use_dynamic_offset: bool = False
+ """Optimize dynamic offsets on top of the FLAME vertices in the canonical space"""
+ add_teeth: bool = True
+ """Add teeth to the FLAME model"""
+ remove_lip_inside: bool = False
+ """Remove the inner part of the lips from the FLAME model"""
+
+ tex_resolution: int = 2048
+ """The resolution of the extra texture map"""
+ tex_painted: bool = True
+ """Use a painted texture map instead the pca texture space as the base texture map"""
+ tex_extra: bool = True
+ """Optimize an extra texture map as the base texture map or the residual texture map"""
+ # tex_clusters: tuple[str, ...] = ("skin", "hair", "sclerae", "lips_tight", "boundary")
+ tex_clusters: tuple[str, ...] = ("skin", "hair", "boundary", "lips_tight", "teeth", "sclerae", "irises")
+ """Regions that are supposed to share a similar color inside"""
+ residual_tex: bool = True
+ """Use the extra texture map as a residual component on top of the base texture"""
+ occluded: tuple[str, ...] = () # to be used for updating stage configs in __post_init__
+ """The regions that are occluded by the hair or garments"""
+
+ flame_params_path: Optional[Path] = None
+
+
+@dataclass()
+class RenderConfig(Config):
+ backend: Literal['nvdiffrast', 'pytorch3d'] = 'nvdiffrast'
+ """The rendering backend"""
+ use_opengl: bool = False
+ """Use OpenGL for NVDiffRast"""
+ background_train: Literal['white', 'black', 'target'] = 'target'
+ """Background color/image for training"""
+ disturb_rate_fg: Optional[float] = 0.5
+ """The rate of disturbance for the foreground"""
+ disturb_rate_bg: Optional[float] = 0.5
+ """The rate of disturbance for the background. 0.6 best for multi-view, 0.3 best for single-view"""
+ background_eval: Literal['white', 'black', 'target'] = 'target'
+ """Background color/image for evaluation"""
+ lighting_type: Literal['constant', 'front', 'front-range', 'SH'] = 'SH'
+ """The type of lighting"""
+ lighting_space: Literal['world', 'camera'] = 'world'
+ """The space of lighting"""
+
+
+@dataclass()
+class LearningRateConfig(Config):
+ base: float = 5e-3
+ """shape, texture, rotation, eyes, neck, jaw"""
+ translation: float = 1e-3
+ expr: float = 5e-2
+ static_offset: float = 5e-4
+ dynamic_offset: float = 5e-4
+ camera: float = 5e-3
+ light: float = 5e-3
+
+
+@dataclass()
+class LossWeightConfig(Config):
+ landmark: Optional[float] = 10.
+ always_enable_jawline_landmarks: bool = True
+ """Always enable the landmark loss for the jawline landmarks. Ignore disable_jawline_landmarks in stages."""
+
+ photo: Optional[float] = 30.
+
+ reg_shape: float = 3e-1
+ reg_expr: float = 3e-2
+ reg_tex_pca: float = 1e-4 # will make it hard to model hair color when too high
+
+ reg_tex_res: Optional[float] = None # 1e2 (when w/o reg_var)
+ """Regularize the residual texture map"""
+ reg_tex_res_clusters: Optional[float] = 1e1
+ """Regularize the residual texture map inside each texture cluster"""
+ reg_tex_res_for: tuple[str, ...] = ("sclerae", "teeth")
+ """Regularize the residual texture map for the clusters specified"""
+ reg_tex_tv: Optional[float] = 1e4 # important to split regions apart
+ """Regularize the total variation of the texture map"""
+
+ reg_light: Optional[float] = None
+ """Regularize lighting parameters"""
+ reg_diffuse: Optional[float] = 1e2
+ """Regularize lighting parameters by the diffuse term"""
+
+ reg_offset: Optional[float] = 3e2
+ """Regularize the norm of offsets"""
+ reg_offset_relax_coef: float = 1.
+ """The coefficient for relaxing reg_offset for the regions specified"""
+ reg_offset_relax_for: tuple[str, ...] = ("hair", "ears")
+ """Relax the offset loss for the regions specified"""
+
+ reg_offset_lap: Optional[float] = 1e6
+ """Regularize the difference of laplacian coordinate caused by offsets"""
+ reg_offset_lap_relax_coef: float = 0.1
+ """The coefficient for relaxing reg_offset_lap for the regions specified"""
+ reg_offset_lap_relax_for: tuple[str, ...] = ("hair", "ears")
+ """Relax the offset loss for the regions specified"""
+
+ reg_offset_rigid: Optional[float] = 3e2
+ """Regularize the the offsets to be as-rigid-as-possible"""
+ reg_offset_rigid_for: tuple[str, ...] = ("left_ear", "right_ear", "neck", "left_eye", "right_eye", "lips_tight")
+ """Regularize the the offsets to be as-rigid-as-possible for the regions specified"""
+
+ reg_offset_dynamic: Optional[float] = 3e5
+ """Regularize the dynamic offsets to be temporally smooth"""
+
+ blur_iter: int = 0
+ """The number of iterations for blurring vertex weights"""
+
+ smooth_trans: float = 3e2
+ """global translation"""
+ smooth_rot: float = 3e1
+ """global rotation"""
+
+ smooth_neck: float = 3e1
+ """neck joint"""
+ smooth_jaw: float = 1e-1
+ """jaw joint"""
+ smooth_eyes: float = 0
+ """eyes joints"""
+
+ prior_neck: float = 3e-1
+ """Regularize the neck joint towards neutral"""
+ prior_jaw: float = 3e-1
+ """Regularize the jaw joint towards neutral"""
+ prior_eyes: float = 3e-2
+ """Regularize the eyes joints towards neutral"""
+
+
+@dataclass()
+class LogConfig(Config):
+ interval_scalar: Optional[int] = 100
+ """The step interval of scalar logging. Using an interval of stage_tracking.num_steps // 5 unless specified."""
+ interval_media: Optional[int] = 500
+ """The step interval of media logging. Using an interval of stage_tracking.num_steps unless specified."""
+ image_format: Literal['jpg', 'png'] = 'jpg'
+ """Output image format"""
+ view_indices: Tuple[int, ...] = ()
+ """Manually specify the view indices for log"""
+ max_num_views: int = 3
+ """The maximum number of views for log"""
+ stack_views_in_rows: bool = True
+
+
+@dataclass()
+class ExperimentConfig(Config):
+ output_folder: Path = Path('output/track')
+ reuse_landmarks: bool = True
+ keyframes: Tuple[int, ...] = tuple()
+ photometric: bool = False
+ """enable photometric optimization, otherwise only landmark optimization"""
+
+@dataclass()
+class StageConfig(Config):
+ disable_jawline_landmarks: bool = False
+ """Disable the landmark loss for the jawline landmarks since they are not accurate"""
+
+@dataclass()
+class StageLmkInitRigidConfig(StageConfig):
+ """The stage for initializing the rigid parameters"""
+ num_steps: int = 300
+ optimizable_params: tuple[str, ...] = ("cam", "pose")
+
+@dataclass()
+class StageLmkInitAllConfig(StageConfig):
+ """The stage for initializing all the parameters optimizable with landmark loss"""
+ num_steps: int = 300
+ optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr")
+
+@dataclass()
+class StageLmkSequentialTrackingConfig(StageConfig):
+ """The stage for sequential tracking with landmark loss"""
+ num_steps: int = 50
+ optimizable_params: tuple[str, ...] = ("pose", "joints", "expr")
+
+@dataclass()
+class StageLmkGlobalTrackingConfig(StageConfig):
+ """The stage for global tracking with landmark loss"""
+ num_epochs: int = 0
+ optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr")
+
+@dataclass()
+class PhotometricStageConfig(StageConfig):
+ align_texture_except: tuple[str, ...] = ()
+ """Align the inner region of rendered FLAME to the image, except for the regions specified"""
+ align_boundary_except: tuple[str, ...] = ("bottomline",) # necessary to avoid the bottomline of FLAME from being stretched to the bottom of the image
+ """Align the boundary of FLAME to the image, except for the regions specified"""
+
+@dataclass()
+class StageRgbInitTextureConfig(PhotometricStageConfig):
+ """The stage for initializing the texture map with photometric loss"""
+ num_steps: int = 500
+ optimizable_params: tuple[str, ...] = ("cam", "shape", "texture", "lights")
+ align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")
+ align_boundary_except: tuple[str, ...] = ("hair", "boundary")
+
+@dataclass()
+class StageRgbInitAllConfig(PhotometricStageConfig):
+ """The stage for initializing all the parameters except the offsets with photometric loss"""
+ num_steps: int = 500
+ optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights")
+ disable_jawline_landmarks: bool = True
+ align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")
+ align_boundary_except: tuple[str, ...] = ("hair", "bottomline")
+
+@dataclass()
+class StageRgbInitOffsetConfig(PhotometricStageConfig):
+ """The stage for initializing the offsets with photometric loss"""
+ num_steps: int = 500
+ optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset")
+ disable_jawline_landmarks: bool = True
+ align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")
+
+@dataclass()
+class StageRgbSequentialTrackingConfig(PhotometricStageConfig):
+ """The stage for sequential tracking with photometric loss"""
+ num_steps: int = 50
+ optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "texture", "dynamic_offset")
+ disable_jawline_landmarks: bool = True
+
+@dataclass()
+class StageRgbGlobalTrackingConfig(PhotometricStageConfig):
+ """The stage for global tracking with photometric loss"""
+ num_epochs: int = 30
+ optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset", "dynamic_offset")
+ disable_jawline_landmarks: bool = True
+
+@dataclass()
+class PipelineConfig(Config):
+ lmk_init_rigid: StageLmkInitRigidConfig
+ lmk_init_all: StageLmkInitAllConfig
+ lmk_sequential_tracking: StageLmkSequentialTrackingConfig
+ lmk_global_tracking: StageLmkGlobalTrackingConfig
+ rgb_init_texture: StageRgbInitTextureConfig
+ rgb_init_all: StageRgbInitAllConfig
+ rgb_init_offset: StageRgbInitOffsetConfig
+ rgb_sequential_tracking: StageRgbSequentialTrackingConfig
+ rgb_global_tracking: StageRgbGlobalTrackingConfig
+
+
+@dataclass()
+class BaseTrackingConfig(Config):
+ data: DataConfig
+ model: ModelConfig
+ render: RenderConfig
+ log: LogConfig
+ exp: ExperimentConfig
+ lr: LearningRateConfig
+ w: LossWeightConfig
+ pipeline: PipelineConfig
+
+ begin_stage: Optional[str] = None
+ """Begin from the specified stage for debugging"""
+ begin_frame_idx: int = 0
+ """Begin from the specified frame index for debugging"""
+ async_func: bool = True
+ """Allow asynchronous function calls for speed up"""
+ device: Literal['cuda', 'cpu'] = 'cuda'
+
+ def get_occluded(self):
+ occluded_table = {
+ }
+ if self.data.sequence in occluded_table:
+ logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.sequence]}")
+ self.model.occluded = occluded_table[self.data.sequence]
+
+ def __post_init__(self):
+ self.get_occluded()
+
+ if not self.model.use_static_offset and not self.model.use_dynamic_offset:
+ self.model.occluded = tuple(list(self.model.occluded) + ['hair']) # disable boundary alignment for the hair region if no offset is used
+
+ for cfg_stage in self.pipeline.__dict__.values():
+ if isinstance(cfg_stage, PhotometricStageConfig):
+ cfg_stage.align_texture_except = tuple(list(cfg_stage.align_texture_except) + list(self.model.occluded))
+ cfg_stage.align_boundary_except = tuple(list(cfg_stage.align_boundary_except) + list(self.model.occluded))
+
+ if self.begin_stage is not None:
+ skip = True
+ for cfg_stage in self.pipeline.__dict__.values():
+ if cfg_stage.__class__.__name__.lower() == self.begin_stage:
+ skip = False
+ if skip:
+ cfg_stage.num_steps = 0
+
+
+if __name__ == "__main__":
+ config = tyro.cli(BaseTrackingConfig)
+ print(tyro.to_yaml(config))
\ No newline at end of file
diff --git a/vhap/config/nersemble.py b/vhap/config/nersemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ab75b105ad2e271e3f5807d02649a0116fa3a9e
--- /dev/null
+++ b/vhap/config/nersemble.py
@@ -0,0 +1,86 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+from typing import Optional, Literal
+from dataclasses import dataclass
+import tyro
+
+from vhap.config.base import (
+ StageRgbSequentialTrackingConfig, StageRgbGlobalTrackingConfig, PipelineConfig,
+ DataConfig, LossWeightConfig, BaseTrackingConfig,
+)
+from vhap.util.log import get_logger
+logger = get_logger(__name__)
+
+
+@dataclass()
+class NersembleDataConfig(DataConfig):
+ _target: str = "vhap.data.nersemble_dataset.NeRSembleDataset"
+ calibrated: bool = True
+ image_size_during_calibration: Optional[tuple[int, int]] = (3208, 2200)
+ """(height, width). Will be use to convert principle points when the image size is not included in the camera parameters."""
+ background_color: Optional[Literal['white', 'black']] = None
+ landmark_source: Optional[Literal["face-alignment", 'star']] = "star"
+
+ subject: str = ""
+ """Subject ID. Such as 018, 218, 251, 253"""
+ use_color_correction: bool = True
+ """Whether to use color correction to harmonize the color of the input images."""
+
+@dataclass()
+class NersembleLossWeightConfig(LossWeightConfig):
+ landmark: Optional[float] = 3. # should not be lower to avoid collapse
+ always_enable_jawline_landmarks: bool = False # allow disable_jawline_landmarks in StageConfig to work
+ reg_expr: float = 1e-2 # for best expressivness
+ reg_tex_tv: Optional[float] = 1e5 # 10x of the base value
+
+@dataclass()
+class NersembleStageRgbSequentialTrackingConfig(StageRgbSequentialTrackingConfig):
+ optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "dynamic_offset")
+
+ align_texture_except: tuple[str, ...] = ("boundary",)
+ align_boundary_except: tuple[str, ...] = ("boundary",)
+ """Due to the limited flexibility in the lower neck region of FLAME, we relax the
+ alignment constraints for better alignment in the face region.
+ """
+
+@dataclass()
+class NersembleStageRgbGlobalTrackingConfig(StageRgbGlobalTrackingConfig):
+ align_texture_except: tuple[str, ...] = ("boundary",)
+ align_boundary_except: tuple[str, ...] = ("boundary",)
+ """Due to the limited flexibility in the lower neck region of FLAME, we relax the
+ alignment constraints for better alignment in the face region.
+ """
+
+@dataclass()
+class NersemblePipelineConfig(PipelineConfig):
+ rgb_sequential_tracking: NersembleStageRgbSequentialTrackingConfig
+ rgb_global_tracking: NersembleStageRgbGlobalTrackingConfig
+
+@dataclass()
+class NersembleTrackingConfig(BaseTrackingConfig):
+ data: NersembleDataConfig
+ w: NersembleLossWeightConfig
+ pipeline: NersemblePipelineConfig
+
+ def get_occluded(self):
+ occluded_table = {
+ '018': ('neck_lower',),
+ '218': ('neck_lower',),
+ '251': ('neck_lower', 'boundary'),
+ '253': ('neck_lower',),
+ }
+ if self.data.subject in occluded_table:
+ logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.subject]}")
+ self.model.occluded = occluded_table[self.data.subject]
+
+
+if __name__ == "__main__":
+ config = tyro.cli(NersembleTrackingConfig)
+ print(tyro.to_yaml(config))
\ No newline at end of file
diff --git a/vhap/data/image_folder_dataset.py b/vhap/data/image_folder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8d657265c1c4a00b1d482fda2dc1e6f3b27f4d7
--- /dev/null
+++ b/vhap/data/image_folder_dataset.py
@@ -0,0 +1,79 @@
+from pathlib import Path
+from typing import Optional
+import numpy as np
+import PIL.Image as Image
+from torch.utils.data import Dataset
+from vhap.util.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class ImageFolderDataset(Dataset):
+ def __init__(
+ self,
+ image_folder: Path,
+ background_folder: Optional[Path]=None,
+ background_fname2camId=lambda x: x,
+ image_fname2camId=lambda x: x,
+ ):
+ """
+ Args:
+ root_folder: Path to dataset with the following directory layout
+ /
+ |---xx.jpg
+ |---...
+ """
+ super().__init__()
+ self.image_fname2camId = image_fname2camId
+ self.background_foler = background_folder
+
+ logger.info(f"Initializing dataset from folder {image_folder}")
+
+ self.image_paths = sorted(list(image_folder.glob('*.jpg')))
+
+ if background_folder is not None:
+ self.backgrounds = {}
+ background_paths = sorted(list((image_folder / background_folder).glob('*.jpg')))
+
+ for background_path in background_paths:
+ bg = np.array(Image.open(background_path))
+ cam_id = background_fname2camId(background_path.name)
+ self.backgrounds[cam_id] = bg
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def __getitem__(self, i):
+ image_path = self.image_paths[i]
+ cam_id = self.image_fname2camId(image_path.name)
+ rgb = np.array(Image.open(image_path))
+ item = {
+ "rgb": rgb,
+ 'image_path': str(image_path),
+ }
+
+ if self.background_foler is not None:
+ item['background'] = self.backgrounds[cam_id]
+
+ return item
+
+
+if __name__ == "__main__":
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+
+ dataset = ImageFolderDataset(
+ image_folder='./xx',
+ img_to_tensor=True,
+ )
+
+ print(len(dataset))
+
+ sample = dataset[0]
+ print(sample.keys())
+ print(sample["rgb"].shape)
+
+ dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
+ for item in tqdm(dataloader):
+ pass
diff --git a/vhap/data/nerf_dataset.py b/vhap/data/nerf_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..09175c6e0c3b2ae43258489eaef78a4eff980344
--- /dev/null
+++ b/vhap/data/nerf_dataset.py
@@ -0,0 +1,161 @@
+from pathlib import Path
+import json
+import numpy as np
+import PIL.Image as Image
+import torch
+import torchvision.transforms.functional as F
+from torch.utils.data import Dataset
+from vhap.util.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class NeRFDataset(Dataset):
+ def __init__(
+ self,
+ root_folder,
+ division=None,
+ camera_convention_conversion=None,
+ target_extrinsic_type='w2c',
+ use_fg_mask=False,
+ use_flame_param=False,
+ ):
+ """
+ Args:
+ root_folder: Path to dataset with the following directory layout
+ /
+ |
+ |---/
+ | |---00000.jpg
+ | |...
+ |
+ |---/
+ | |---00000.png
+ | |...
+ |
+ |---/
+ | |---00000.npz
+ | |...
+ |
+ |---transforms_backup.json # backup of the original transforms.json
+ |---transforms_backup_flame.json # backup of the original transforms.json with flame_param
+ |---transforms.json # the final transforms.json
+ |---transforms_train.json # the final transforms.json for training
+ |---transforms_val.json # the final transforms.json for validation
+ |---transforms_test.json # the final transforms.json for testing
+
+
+ """
+
+ super().__init__()
+ self.root_folder = Path(root_folder)
+ self.division = division
+ self.camera_convention_conversion = camera_convention_conversion
+ self.target_extrinsic_type = target_extrinsic_type
+ self.use_fg_mask = use_fg_mask
+ self.use_flame_param = use_flame_param
+
+ logger.info(f"Loading NeRF scene from: {root_folder}")
+
+ # data division
+ if division is None:
+ tranform_path = self.root_folder / "transforms.json"
+ elif division == "train":
+ tranform_path = self.root_folder / "transforms_train.json"
+ elif division == "val":
+ tranform_path = self.root_folder / "transforms_val.json"
+ elif division == "test":
+ tranform_path = self.root_folder / "transforms_test.json"
+ else:
+ raise NotImplementedError(f"Unknown division type: {division}")
+ logger.info(f"division: {division}")
+
+ self.transforms = json.load(open(tranform_path, "r"))
+ logger.info(f"number of timesteps: {len(self.transforms['timestep_indices'])}, number of cameras: {len(self.transforms['camera_indices'])}")
+
+ assert len(self.transforms['timestep_indices']) == max(self.transforms['timestep_indices']) + 1
+
+ def __len__(self):
+ return len(self.transforms['frames'])
+
+ def __getitem__(self, i):
+ frame = self.transforms['frames'][i]
+
+ # 'timestep_index', 'timestep_index_original', 'timestep_id', 'camera_index', 'camera_id', 'cx', 'cy', 'fl_x', 'fl_y', 'h', 'w', 'camera_angle_x', 'camera_angle_y', 'transform_matrix', 'file_path', 'fg_mask_path', 'flame_param_path']
+
+ K = torch.eye(3)
+ K[[0, 1, 0, 1], [0, 1, 2, 2]] = torch.tensor(
+ [frame["fl_x"], frame["fl_y"], frame["cx"], frame["cy"]]
+ )
+
+ c2w = torch.tensor(frame['transform_matrix'])
+ if self.target_extrinsic_type == "w2c":
+ extrinsic = c2w.inverse()
+ elif self.target_extrinsic_type == "c2w":
+ extrinsic = c2w
+ else:
+ raise NotImplementedError(f"Unknown extrinsic type: {self.target_extrinsic_type}")
+
+ img_path = self.root_folder / frame['file_path']
+
+ item = {
+ 'timestep_index': frame['timestep_index'],
+ 'camera_index': frame['camera_index'],
+ 'intrinsics': K,
+ 'extrinsics': extrinsic,
+ 'image_height': frame['h'],
+ 'image_width': frame['w'],
+ 'image': np.array(Image.open(img_path)),
+ 'image_path': img_path,
+ }
+
+ if self.use_fg_mask and 'fg_mask_path' in frame:
+ fg_mask_path = self.root_folder / frame['fg_mask_path']
+ item["fg_mask"] = np.array(Image.open(fg_mask_path))
+ item["fg_mask_path"] = fg_mask_path
+
+ if self.use_flame_param and 'flame_param_path' in frame:
+ npz = np.load(self.root_folder / frame['flame_param_path'], allow_pickle=True)
+ item["flame_param"] = dict(npz)
+
+ return item
+
+ def apply_to_tensor(self, item):
+ if self.img_to_tensor:
+ if "rgb" in item:
+ item["rgb"] = F.to_tensor(item["rgb"])
+ # if self.rgb_range_shift:
+ # item["rgb"] = (item["rgb"] - 0.5) / 0.5
+
+ if "alpha_map" in item:
+ item["alpha_map"] = F.to_tensor(item["alpha_map"])
+ return item
+
+
+if __name__ == "__main__":
+ from tqdm import tqdm
+ from dataclasses import dataclass
+ import tyro
+ from torch.utils.data import DataLoader
+
+ @dataclass
+ class Args:
+ root_folder: str
+ subject: str
+ sequence: str
+ use_landmark: bool = False
+ batchify_all_views: bool = False
+
+ args = tyro.cli(Args)
+
+ dataset = NeRFDataset(root_folder=args.root_folder)
+
+ print(len(dataset))
+
+ sample = dataset[0]
+ print(sample.keys())
+
+ dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
+ for item in tqdm(dataloader):
+ pass
diff --git a/vhap/data/nersemble_dataset.py b/vhap/data/nersemble_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3de1e050f557ab580453d517c18af0d8d9a77b6f
--- /dev/null
+++ b/vhap/data/nersemble_dataset.py
@@ -0,0 +1,183 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+import json
+import numpy as np
+import torch
+from vhap.data.video_dataset import VideoDataset
+from vhap.config.nersemble import NersembleDataConfig
+from vhap.util import camera
+from vhap.util.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class NeRSembleDataset(VideoDataset):
+ def __init__(
+ self,
+ cfg: NersembleDataConfig,
+ img_to_tensor: bool = False,
+ batchify_all_views: bool = False,
+ ):
+ """
+ Args:
+ root_folder: Path to dataset with the following directory layout
+ /
+ |---camera_params/
+ | |---/
+ | |---camera_params.json
+ |
+ |---color_correction/
+ | |---/
+ | |---.npy
+ |
+ |---/
+ |---/
+ |---images/
+ | |---cam__.jpg
+ |
+ |---alpha_maps/
+ | |---cam__.png
+ |
+ |---landmark2d/
+ |---face-alignment/
+ | |---.npz
+ |
+ |---STAR/
+ |---.npz
+ """
+ self.cfg = cfg
+ assert cfg.subject != "", "Please specify the subject name"
+
+ super().__init__(
+ cfg=cfg,
+ img_to_tensor=img_to_tensor,
+ batchify_all_views=batchify_all_views,
+ )
+
+ def match_sequences(self):
+ logger.info(f"Subject: {self.cfg.subject}, sequence: {self.cfg.sequence}")
+ return list(filter(lambda x: x.is_dir(), (self.cfg.root_folder / self.cfg.subject).glob(f"{self.cfg.sequence}*")))
+
+ def define_properties(self):
+ super().define_properties()
+ self.properties['rgb']['cam_id_prefix'] = "cam_"
+ self.properties['alpha_map']['cam_id_prefix'] = "cam_"
+
+ def load_camera_params(self):
+ load_path = self.cfg.root_folder / "camera_params" / self.cfg.subject / "camera_params.json"
+ assert load_path.exists()
+ param = json.load(open(load_path))
+
+ K = torch.Tensor(param["intrinsics"])
+
+ if "height" not in param or "width" not in param:
+ assert self.cfg.image_size_during_calibration is not None
+ H, W = self.cfg.image_size_during_calibration
+ else:
+ H, W = param["height"], param["width"]
+
+ self.camera_ids = list(param["world_2_cam"].keys())
+ w2c = torch.tensor([param["world_2_cam"][k] for k in self.camera_ids]) # (N, 4, 4)
+ R = w2c[..., :3, :3]
+ T = w2c[..., :3, 3]
+
+ orientation = R.transpose(-1, -2) # (N, 3, 3)
+ location = R.transpose(-1, -2) @ -T[..., None] # (N, 3, 1)
+
+ # adjust how cameras distribute in the space with a global rotation
+ if self.cfg.align_cameras_to_axes:
+ orientation, location = camera.align_cameras_to_axes(
+ orientation, location, target_convention="opengl"
+ )
+
+ # modify the local orientation of cameras to fit in different camera conventions
+ if self.cfg.camera_convention_conversion is not None:
+ orientation, K = camera.convert_camera_convention(
+ self.cfg.camera_convention_conversion, orientation, K, H, W
+ )
+
+ c2w = torch.cat([orientation, location], dim=-1) # camera-to-world transformation
+
+ if self.cfg.target_extrinsic_type == "w2c":
+ R = orientation.transpose(-1, -2)
+ T = orientation.transpose(-1, -2) @ -location
+ w2c = torch.cat([R, T], dim=-1) # world-to-camera transformation
+ extrinsic = w2c
+ elif self.cfg.target_extrinsic_type == "c2w":
+ extrinsic = c2w
+ else:
+ raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}")
+
+ self.camera_params = {}
+ for i, camera_id in enumerate(self.camera_ids):
+ self.camera_params[camera_id] = {"intrinsic": K, "extrinsic": extrinsic[i]}
+
+ def filter_division(self, division):
+ if division is not None:
+ cam_for_train = [8, 7, 9, 4, 10, 5, 13, 2, 12, 1, 14, 0]
+ if division == "train":
+ self.camera_ids = [
+ self.camera_ids[i]
+ for i in range(len(self.camera_ids))
+ if i in cam_for_train
+ ]
+ elif division == "val":
+ self.camera_ids = [
+ self.camera_ids[i]
+ for i in range(len(self.camera_ids))
+ if i not in cam_for_train
+ ]
+ elif division == "front-view":
+ self.camera_ids = self.camera_ids[8:9]
+ elif division == "side-view":
+ self.camera_ids = self.camera_ids[0:1]
+ elif division == "six-view":
+ self.camera_ids = [self.camera_ids[i] for i in [0, 1, 7, 8, 14, 15]]
+ else:
+ raise NotImplementedError(f"Unknown division type: {division}")
+ logger.info(f"division: {division}")
+
+ def apply_transforms(self, item):
+ if self.cfg.use_color_correction:
+ color_correction_path = self.cfg.root_folder / 'color_correction' / self.cfg.subject / f'{item["camera_id"]}.npy'
+ affine_color_transform = np.load(color_correction_path)
+ rgb = item["rgb"] / 255
+ rgb = rgb @ affine_color_transform[:3, :3] + affine_color_transform[np.newaxis, :3, 3]
+ item["rgb"] = (np.clip(rgb, 0, 1) * 255).astype(np.uint8)
+
+ super().apply_transforms(item)
+ return item
+
+
+if __name__ == "__main__":
+ import tyro
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ from vhap.config.nersemble import NersembleDataConfig
+ from vhap.config.base import import_module
+
+ cfg = tyro.cli(NersembleDataConfig)
+ cfg.use_landmark = False
+ dataset = import_module(cfg._target)(
+ cfg=cfg,
+ img_to_tensor=False,
+ batchify_all_views=True,
+ )
+
+ print(len(dataset))
+
+ sample = dataset[0]
+ print(sample.keys())
+ print(sample["rgb"].shape)
+
+ dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
+ for item in tqdm(dataloader):
+ pass
diff --git a/vhap/data/video_dataset.py b/vhap/data/video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e8f4a0cfdf100f8e53a2c3cfefa2c47439999e7
--- /dev/null
+++ b/vhap/data/video_dataset.py
@@ -0,0 +1,418 @@
+import os
+from pathlib import Path
+from copy import deepcopy
+from typing import Optional
+import numpy as np
+import PIL.Image as Image
+import torch
+import torchvision.transforms.functional as F
+from torch.utils.data import Dataset, default_collate
+import json
+from vhap.util.log import get_logger
+from vhap.config.base import DataConfig
+
+
+logger = get_logger(__name__)
+
+
+class VideoDataset(Dataset):
+ def __init__(
+ self,
+ cfg: DataConfig,
+ img_to_tensor: bool = False,
+ batchify_all_views: bool = False,
+ ):
+ """
+ Args:
+ root_folder: Path to dataset with the following directory layout
+ /
+ |---images/
+ | |---.jpg
+ |
+ |---alpha_maps/
+ | |---.png
+ |
+ |---landmark2d/
+ |---face-alignment/
+ | |---.npz
+ |
+ |---STAR/
+ |---.npz
+ """
+ super().__init__()
+ self.cfg = cfg
+ self.img_to_tensor = img_to_tensor
+ self.batchify_all_views = batchify_all_views
+
+ sequence_paths = self.match_sequences()
+ if len(sequence_paths) > 1:
+ logger.info(f"Found multiple sequences: {sequence_paths}")
+ raise ValueError(f"Found multiple sequences by '{cfg.sequence}': \n" + "\n\t".join([str(x) for x in sequence_paths]))
+ elif len(sequence_paths) == 0:
+ raise ValueError(f"Cannot find sequence: {cfg.sequence}")
+ self.sequence_path = sequence_paths[0]
+ logger.info(f"Initializing dataset from {self.sequence_path}")
+
+ self.define_properties()
+ self.load_camera_params()
+
+ # timesteps
+ self.timestep_ids = set(
+ f.split('.')[0].split('_')[-1]
+ for f in os.listdir(self.sequence_path / self.properties['rgb']['folder']) if f.endswith(self.properties['rgb']['suffix'])
+ )
+ self.timestep_ids = sorted(self.timestep_ids)
+ self.timestep_indices = list(range(len(self.timestep_ids)))
+
+ self.filter_division(cfg.division)
+ self.filter_subset(cfg.subset)
+
+ logger.info(f"number of timesteps: {self.num_timesteps}, number of cameras: {self.num_cameras}")
+
+ # collect
+ self.items = []
+ for fi, timestep_index in enumerate(self.timestep_indices):
+ for ci, camera_id in enumerate(self.camera_ids):
+ self.items.append(
+ {
+ "timestep_index": fi, # new index after filtering
+ "timestep_index_original": timestep_index, # original index
+ "timestep_id": self.timestep_ids[timestep_index],
+ "camera_index": ci,
+ "camera_id": camera_id,
+ }
+ )
+
+ def match_sequences(self):
+ logger.info(f"Looking for sequence '{self.cfg.sequence}' at {self.cfg.root_folder}")
+ return list(filter(lambda x: x.is_dir(), self.cfg.root_folder.glob(f"{self.cfg.sequence}*")))
+
+ def define_properties(self):
+ self.properties = {
+ "rgb": {
+ "folder": f"images_{self.cfg.n_downsample_rgb}"
+ if self.cfg.n_downsample_rgb
+ else "images",
+ "per_timestep": True,
+ # "suffix": "jpg",
+ "suffix": "png",
+ },
+ "alpha_map": {
+ "folder": "alpha_maps",
+ "per_timestep": True,
+ "suffix": "jpg",
+ },
+ "landmark2d/face-alignment": {
+ "folder": "landmark2d/face-alignment",
+ "per_timestep": False,
+ "suffix": "npz",
+ },
+ "landmark2d/STAR": {
+ "folder": "landmark2d/STAR",
+ "per_timestep": False,
+ "suffix": "npz",
+ },
+ "landmark2d/lms": {
+ "folder": "landmark2d/landmarks",
+ "per_timestep": False,
+ "suffix": "npz",
+ },
+ }
+
+ @staticmethod
+ def get_number_after_prefix(string, prefix):
+ i = string.find(prefix)
+ if i != -1:
+ number_begin = i + len(prefix)
+ assert number_begin < len(string), f"No number found behind prefix '{prefix}'"
+ assert string[number_begin].isdigit(), f"No number found behind prefix '{prefix}'"
+
+ non_digit_indices = [i for i, c in enumerate(string[number_begin:]) if not c.isdigit()]
+ if len(non_digit_indices) > 0:
+ number_end = number_begin + min(non_digit_indices)
+ return int(string[number_begin:number_end])
+ else:
+ return int(string[number_begin:])
+ else:
+ return None
+
+ def filter_division(self, division):
+ pass
+
+ def filter_subset(self, subset):
+ if subset is not None:
+ if 'ti' in subset:
+ ti = self.get_number_after_prefix(subset, 'ti')
+ if 'tj' in subset:
+ tj = self.get_number_after_prefix(subset, 'tj')
+ self.timestep_indices = self.timestep_indices[ti:tj+1]
+ else:
+ self.timestep_indices = self.timestep_indices[ti:ti+1]
+ elif 'tn' in subset:
+ tn = self.get_number_after_prefix(subset, 'tn')
+ tn_all = len(self.timestep_indices)
+ tn = min(tn, tn_all)
+ self.timestep_indices = self.timestep_indices[::tn_all // tn][:tn]
+ elif 'ts' in subset:
+ ts = self.get_number_after_prefix(subset, 'ts')
+ self.timestep_indices = self.timestep_indices[::ts]
+ if 'ci' in subset:
+ ci = self.get_number_after_prefix(subset, 'ci')
+ self.camera_ids = self.camera_ids[ci:ci+1]
+ elif 'cn' in subset:
+ cn = self.get_number_after_prefix(subset, 'cn')
+ cn_all = len(self.camera_ids)
+ cn = min(cn, cn_all)
+ self.camera_ids = self.camera_ids[::cn_all // cn][:cn]
+ elif 'cs' in subset:
+ cs = self.get_number_after_prefix(subset, 'cs')
+ self.camera_ids = self.camera_ids[::cs]
+
+ def load_camera_params(self):
+ self.camera_ids = ['0']
+
+ # Guessed focal length, height, width. Should be optimized or replaced by real values
+ f, h, w = 512, 512, 512
+ K = torch.Tensor([
+ [f, 0, w],
+ [0, f, h],
+ [0, 0, 1]
+ ])
+
+ orientation = torch.eye(3)[None, ...] # (1, 3, 3)
+ location = torch.Tensor([0, 0, 1])[None, ..., None] # (1, 3, 1)
+
+ c2w = torch.cat([orientation, location], dim=-1) # camera-to-world transformation
+
+ if self.cfg.target_extrinsic_type == "w2c":
+ R = orientation.transpose(-1, -2)
+ T = orientation.transpose(-1, -2) @ -location
+ w2c = torch.cat([R, T], dim=-1) # world-to-camera transformation
+ extrinsic = w2c
+ elif self.cfg.target_extrinsic_type == "c2w":
+ extrinsic = c2w
+ else:
+ raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}")
+
+ self.camera_params = {}
+ for i, camera_id in enumerate(self.camera_ids):
+ self.camera_params[camera_id] = {"intrinsic": K, "extrinsic": extrinsic[i]}
+
+ return self.camera_params
+
+ def __len__(self):
+ if self.batchify_all_views:
+ return self.num_timesteps
+ else:
+ return len(self.items)
+
+ def __getitem__(self, i):
+ if self.batchify_all_views:
+ return self.getitem_by_timestep(i)
+ else:
+ return self.getitem_single_image(i)
+
+ def getitem_single_image(self, i):
+ item = deepcopy(self.items[i])
+
+ rgb_path = self.get_property_path("rgb", i)
+ item["rgb"] = np.array(Image.open(rgb_path))[:, :, :3]
+
+ camera_param = self.camera_params[item["camera_id"]]
+ item["intrinsic"] = camera_param["intrinsic"].clone()
+ item["extrinsic"] = camera_param["extrinsic"].clone()
+
+ if self.cfg.use_alpha_map or self.cfg.background_color is not None:
+ alpha_path = self.get_property_path("alpha_map", i)
+ item["alpha_map"] = np.array(Image.open(alpha_path))
+
+ if self.cfg.use_landmark:
+ timestep_index = self.items[i]["timestep_index"]
+
+ landmark_path = self.get_property_path("landmark2d/lms", i)
+ landmark_npz = np.load(landmark_path)
+
+ lms_eyes_path = os.path.join(os.path.dirname(landmark_path),'iris.json')
+
+ item["lmk2d"] = landmark_npz["face_landmark_2d"][timestep_index] # (num_points, 3)
+ if (item["lmk2d"][:, :2] == -1).sum() > 0:
+ item["lmk2d"][:, 2:] = 0.0
+ else:
+ item["lmk2d"][:, 2:] = 1.0
+
+ if(os.path.exists(lms_eyes_path)):
+ with open(lms_eyes_path,'r') as f:
+ lms_eye = json.load(f)
+ lms_eye = np.array([lms_eye[key] for key in lms_eye][timestep_index]).reshape((2,2)) / 1024.
+ lms_eye = np.concatenate([lms_eye,np.ones((2,1))],axis=1)[(1,0),:]
+ item["lmk2d"] = np.concatenate([item["lmk2d"], lms_eye], 0)
+ else:
+ item["lmk2d"] = np.concatenate([item["lmk2d"]], 0)
+
+ item = self.apply_transforms(item)
+ return item
+
+ def getitem_by_timestep(self, timestep_index):
+ begin = timestep_index * self.num_cameras
+ indices = range(begin, begin + self.num_cameras)
+ item = default_collate([self.getitem_single_image(i) for i in indices])
+
+ item["num_cameras"] = self.num_cameras
+ return item
+
+ def apply_transforms(self, item):
+ item = self.apply_scale_factor(item)
+ item = self.apply_background_color(item)
+ item = self.apply_to_tensor(item)
+ return item
+
+ def apply_to_tensor(self, item):
+ if self.img_to_tensor:
+ if "rgb" in item:
+ item["rgb"] = F.to_tensor(item["rgb"])
+
+ if "alpha_map" in item:
+ item["alpha_map"] = F.to_tensor(item["alpha_map"])
+ return item
+
+ def apply_scale_factor(self, item):
+ assert self.cfg.scale_factor <= 1.0
+
+ if "rgb" in item:
+ H, W, _ = item["rgb"].shape
+ h, w = int(H * self.cfg.scale_factor), int(W * self.cfg.scale_factor)
+ rgb = Image.fromarray(item["rgb"]).resize(
+ (w, h), resample=Image.BILINEAR
+ )
+ item["rgb"] = np.array(rgb)
+
+ # properties that are defined based on image size
+ if "lmk2d" in item:
+ item["lmk2d"][..., 0] *= w
+ item["lmk2d"][..., 1] *= h
+
+ if "lmk2d_iris" in item:
+ item["lmk2d_iris"][..., 0] *= w
+ item["lmk2d_iris"][..., 1] *= h
+
+ if "bbox_2d" in item:
+ item["bbox_2d"][[0, 2]] *= w
+ item["bbox_2d"][[1, 3]] *= h
+
+ # properties need to be scaled down when rgb is downsampled
+ n_downsample_rgb = self.cfg.n_downsample_rgb if self.cfg.n_downsample_rgb else 1
+ scale_factor = self.cfg.scale_factor / n_downsample_rgb
+ item["scale_factor"] = scale_factor # NOTE: not self.cfg.scale_factor
+ if scale_factor < 1.0:
+ if "intrinsic" in item:
+ item["intrinsic"][:2] *= scale_factor
+ if "alpha_map" in item:
+ h, w = item["rgb"].shape[:2]
+ alpha_map = Image.fromarray(item["alpha_map"]).resize(
+ (w, h), Image.Resampling.BILINEAR
+ )
+ item["alpha_map"] = np.array(alpha_map)
+ return item
+
+ def apply_background_color(self, item):
+ if self.cfg.background_color is not None:
+ assert (
+ "alpha_map" in item
+ ), "'alpha_map' is required to apply background color."
+ fg = item["rgb"]
+ if self.cfg.background_color == "white":
+ bg = np.ones_like(fg) * 255
+ elif self.cfg.background_color == "black":
+ bg = np.zeros_like(fg)
+ else:
+ raise NotImplementedError(
+ f"Unknown background color: {self.cfg.background_color}."
+ )
+
+ # w = item["alpha_map"][..., None] / 255
+ w = item["alpha_map"] / 255
+ img = (w * fg + (1 - w) * bg).astype(np.uint8)
+ item["rgb"] = img
+ return item
+
+ def get_property_path(
+ self,
+ name,
+ index: Optional[int] = None,
+ timestep_id: Optional[str] = None,
+ camera_id: Optional[str] = None,
+ ):
+ p = self.properties[name]
+ folder = p["folder"] if "folder" in p else None
+ per_timestep = p["per_timestep"]
+ suffix = p["suffix"]
+
+ path = self.sequence_path
+ if folder is not None:
+ path = path / folder
+
+ if self.num_cameras > 1:
+ if camera_id is None:
+ assert (
+ index is not None), "index is required when camera_id is not provided."
+ camera_id = self.items[index]["camera_id"]
+ if "cam_id_prefix" in p:
+ camera_id = p["cam_id_prefix"] + camera_id
+ else:
+ camera_id = ""
+
+ if per_timestep:
+ if timestep_id is None:
+ assert index is not None, "index is required when timestep_id is not provided."
+ timestep_id = self.items[index]["timestep_id"]
+ if len(camera_id) > 0:
+ path /= f"{camera_id}_{timestep_id}.{suffix}"
+ else:
+ path /= f"{timestep_id}.{suffix}"
+ else:
+ if len(camera_id) > 0:
+ path /= f"{camera_id}.{suffix}"
+ else:
+ path = Path(str(path) + f".{suffix}")
+
+ return path
+
+ def get_property_path_list(self, name):
+ paths = []
+ for i in range(len(self.items)):
+ img_path = self.get_property_path(name, i)
+ paths.append(img_path)
+ return paths
+
+ @property
+ def num_timesteps(self):
+ return len(self.timestep_indices)
+
+ @property
+ def num_cameras(self):
+ return len(self.camera_ids)
+
+
+if __name__ == "__main__":
+ import tyro
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ from vhap.config.base import DataConfig, import_module
+
+ cfg = tyro.cli(DataConfig)
+ cfg.use_landmark = False
+ dataset = import_module(cfg._target)(
+ cfg=cfg,
+ img_to_tensor=False,
+ batchify_all_views=True,
+ )
+
+ print(len(dataset))
+
+ sample = dataset[0]
+ print(sample.keys())
+ print(sample["rgb"].shape)
+
+ dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
+ for item in tqdm(dataloader):
+ pass
diff --git a/vhap/export_as_nerf_dataset.py b/vhap/export_as_nerf_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ead0dbe06c500c613639b7182aa00a6505970927
--- /dev/null
+++ b/vhap/export_as_nerf_dataset.py
@@ -0,0 +1,657 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+import math
+from typing import Optional, Literal, Dict, List
+from glob import glob
+import concurrent.futures
+import multiprocessing
+from copy import deepcopy
+import yaml
+import json
+import tyro
+from pathlib import Path
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+import torchvision
+# from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle
+
+from vhap.config.base import DataConfig, ModelConfig, import_module
+from vhap.data.nerf_dataset import NeRFDataset
+from vhap.model.flame import FlameHead
+from vhap.util.mesh import get_obj_content
+from vhap.util.render_nvdiffrast import NVDiffRenderer
+
+# to prevent "OSError: [Errno 24] Too many open files"
+import torch.multiprocessing
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+max_threads = min(multiprocessing.cpu_count(), 8)
+
+
+class NeRFDatasetWriter:
+ def __init__(self, cfg_data: DataConfig, tgt_folder: Path, subset:Optional[str]=None, scale_factor: Optional[float]=None, background_color: Optional[str]=None):
+ self.cfg_data = cfg_data
+ self.tgt_folder = tgt_folder
+
+ print("==== Config: data ====")
+ print(tyro.to_yaml(cfg_data))
+
+ cfg_data.target_extrinsic_type = 'c2w'
+ cfg_data.background_color = 'white'
+ cfg_data.use_alpha_map = True
+ dataset = import_module(cfg_data._target)(cfg=cfg_data)
+ self.dataloader = DataLoader(dataset, shuffle=False, batch_size=None, collate_fn=lambda x: x, num_workers=0)
+
+ def write(self):
+ if not self.tgt_folder.exists():
+ self.tgt_folder.mkdir(parents=True)
+
+ db = {
+ "frames": [],
+ }
+
+ print(f"Writing images to {self.tgt_folder}")
+ worker_args = []
+ timestep_indices = set()
+ camera_indices = set()
+ for i, item in tqdm(enumerate(self.dataloader), total=len(self.dataloader)):
+ # print(item.keys())
+
+ timestep_indices.add(item['timestep_index'])
+ camera_indices.add(item['camera_index'])
+
+ extrinsic = item['extrinsic']
+ transform_matrix = torch.cat([extrinsic, torch.tensor([[0,0,0,1]])], dim=0).numpy()
+
+ intrinsic = item['intrinsic'].double().numpy()
+
+ cx = intrinsic[0, 2]
+ cy = intrinsic[1, 2]
+ fl_x = intrinsic[0, 0]
+ fl_y = intrinsic[1, 1]
+ h = item['rgb'].shape[0]
+ w = item['rgb'].shape[1]
+ angle_x = math.atan(w / (fl_x * 2)) * 2
+ angle_y = math.atan(h / (fl_y * 2)) * 2
+
+ frame_item = {
+ "timestep_index": item['timestep_index'],
+ "timestep_index_original": item['timestep_index_original'],
+ "timestep_id": item['timestep_id'],
+ "camera_index": item['camera_index'],
+ "camera_id": item['camera_id'],
+
+ "cx": cx,
+ "cy": cy,
+ "fl_x": fl_x,
+ "fl_y": fl_y,
+ "h": h,
+ "w": w,
+ "camera_angle_x": angle_x,
+ "camera_angle_y": angle_y,
+
+ "transform_matrix": transform_matrix.tolist(),
+
+ "file_path": f"images/{item['timestep_index']:05d}_{item['camera_index']:02d}.png",
+ }
+
+ path2data = {
+ str(self.tgt_folder / frame_item['file_path']): item['rgb'],
+ }
+
+ if 'alpha_map' in item:
+ frame_item['fg_mask_path'] = f"fg_masks/{item['timestep_index']:05d}_{item['camera_index']:02d}.png"
+ path2data[str(self.tgt_folder / frame_item['fg_mask_path'])] = item['alpha_map']
+
+ db['frames'].append(frame_item)
+ worker_args.append([path2data])
+
+ #--- no threading
+ # if len(worker_args) > 0:
+ # write_data(path2data)
+
+ #--- threading
+ if len(worker_args) == max_threads or i == len(self.dataloader)-1:
+ with concurrent.futures.ThreadPoolExecutor(max_threads) as executor:
+ futures = [executor.submit(write_data, *args) for args in worker_args]
+ concurrent.futures.wait(futures)
+ worker_args = []
+
+ # add shared intrinsic parameters to be compatible with other nerf libraries
+ db.update({
+ "cx": cx,
+ "cy": cy,
+ "fl_x": fl_x,
+ "fl_y": fl_y,
+ "h": h,
+ "w": w,
+ "camera_angle_x": angle_x,
+ "camera_angle_y": angle_y
+ })
+
+ # add indices to ease filtering
+ db['timestep_indices'] = sorted(list(timestep_indices))
+ db['camera_indices'] = sorted(list(camera_indices))
+
+ write_json(db, self.tgt_folder)
+ write_json(db, self.tgt_folder, division='backup')
+
+
+class TrackedFLAMEDatasetWriter:
+ def __init__(self, cfg_model: ModelConfig, src_folder: Path, tgt_folder: Path, mode: Literal['mesh', 'param'], epoch: int = -1):
+ print("---- Config: model ----")
+ print(tyro.to_yaml(cfg_model))
+
+ self.cfg_model = cfg_model
+ self.src_folder = src_folder
+ self.tgt_folder = tgt_folder
+ self.mode = mode
+
+ db_backup_path = tgt_folder / "transforms_backup.json"
+ assert db_backup_path.exists(), f"Could not find {db_backup_path}"
+ print(f"Loading database from: {db_backup_path}")
+ self.db = json.load(open(db_backup_path, "r"))
+
+ paths = [Path(p) for p in glob(str(src_folder / "tracked_flame_params*.npz"))]
+ epochs = [int(p.stem.split('_')[-1]) for p in paths]
+ if epoch == -1:
+ index = np.argmax(epochs)
+ else:
+ index = epochs.index(epoch)
+ flame_params_path = paths[index]
+
+ assert flame_params_path.exists(), f"Could not find {flame_params_path}"
+ print(f"Loading FLAME parameters from: {flame_params_path}")
+ self.flame_params = dict(np.load(flame_params_path))
+
+ if "focal_length" in self.flame_params:
+ self.focal_length = self.flame_params['focal_length'].item()
+ else:
+ self.focal_length = None
+
+ # Relocate FLAME to the origin and return the transformation matrix to modify camera poses.
+ self.M = self.relocate_flame_meshes(self.flame_params)
+
+ print("Initializing FLAME model...")
+ self.flame_model = FlameHead(cfg_model.n_shape, cfg_model.n_expr, add_teeth=True)
+
+ def relocate_flame_meshes(self, flame_param):
+ """ Relocate FLAME to the origin and return the transformation matrix to modify camera poses. """
+ # Rs = torch.tensor(flame_param['rotation'])
+ Ts = torch.tensor(flame_param['translation'])
+
+ # R_mean = axis_angle_to_matrix(Rs.mean(0))
+ T_mean = Ts.mean(0)
+ M = torch.eye(4)
+ # M[:3, :3] = R_mean.transpose(-1, -2)
+ M[:3, 3] = -T_mean
+
+ # flame_param['rotation'] = (matrix_to_axis_angle(M[None, :3, :3] @ axis_angle_to_matrix(Rs))).numpy()
+ flame_param['translation'] = (M[:3, 3] + Ts).numpy()
+ return M.numpy()
+
+ def replace_cam_params(self, item):
+ c2w = np.eye(4)
+ c2w[2, 3] = 1 # place the camera at (0, 0, 1) in the world coordinate by default
+ item['transform_matrix'] = c2w
+
+ h = item['h']
+ w = item['w']
+ fl_x = self.focal_length * max(h, w)
+ fl_y = self.focal_length * max(h, w)
+ angle_x = math.atan(w / (fl_x * 2)) * 2
+ angle_y = math.atan(h / (fl_y * 2)) * 2
+
+ item.update({
+ "cx": w / 2,
+ "cy": h / 2,
+ "fl_x": fl_x,
+ "fl_y": fl_y,
+ "camera_angle_x": angle_x,
+ "camera_angle_y": angle_y,
+
+ "transform_matrix": c2w.tolist(),
+ })
+
+ def write(self):
+ if self.mode == 'mesh':
+ self.write_canonical_mesh()
+ indices = self.db['timestep_indices']
+ verts = infer_flame_params(self.flame_model, self.flame_params, indices)
+
+ print(f"Writing FLAME expressions and meshes to: {self.tgt_folder}")
+ elif self.mode == 'param':
+ self.write_canonical_flame_param()
+ print(f"Writing FLAME parameters to: {self.tgt_folder}")
+
+ saved = [False] * len(self.db['timestep_indices']) # avoid writing the same mesh multiple times
+ num_processes = 0
+ worker_args = []
+ for i, frame in tqdm(enumerate(self.db['frames']), total=len(self.db['frames'])):
+ if self.focal_length is not None:
+ self.replace_cam_params(frame)
+ # modify the camera extrinsics to place the tracked FLAME at the origin
+ frame['transform_matrix'] = (self.M @ np.array(frame['transform_matrix'])).tolist()
+
+ ti_orig = frame['timestep_index_original'] # use ti_orig when loading FLAME parameters
+ ti = frame['timestep_index'] # use ti when saving files
+
+ # write FLAME mesh or parameters
+ if self.mode == 'mesh':
+ frame['exp_path'] = f"flame/exp/{ti:05d}.txt"
+ frame['mesh_path'] = f"meshes/{ti:05d}.obj"
+ if not saved[ti]:
+ worker_args.append([self.tgt_folder, frame['exp_path'], self.flame_params['expr'][ti_orig], frame['mesh_path'], verts[ti_orig], self.flame_model.faces])
+ saved[ti] = True
+ func = self.write_expr_and_mesh
+ elif self.mode == 'param':
+ frame['flame_param_path'] = f"flame_param/{ti:05d}.npz"
+ if not saved[ti]:
+ worker_args.append([self.tgt_folder, frame['flame_param_path'], self.flame_params, ti_orig])
+ saved[ti] = True
+ func = self.write_flame_param
+ #--- no multiprocessing
+ if len(worker_args) > 0:
+ func(*worker_args.pop())
+ #--- multiprocessing
+ # if len(worker_args) == num_processes or i == len(self.db['frames'])-1:
+ # pool = multiprocessing.Pool(processes=num_processes)
+ # pool.starmap(func, worker_args)
+ # pool.close()
+ # pool.join()
+ # worker_args = []
+
+ write_json(self.db, self.tgt_folder)
+ write_json(self.db, self.tgt_folder, division='backup_flame')
+
+ def write_canonical_mesh(self):
+ print(f"Inferencing FLAME in the canonical space...")
+ if 'static_offset' in self.flame_params:
+ static_offset = torch.tensor(self.flame_params['static_offset'])
+ else:
+ static_offset = None
+ with torch.no_grad():
+ ret = self.flame_model(
+ torch.tensor(self.flame_params['shape'])[None, ...],
+ torch.zeros(*self.flame_params['expr'][:1].shape),
+ torch.zeros(*self.flame_params['rotation'][:1].shape),
+ torch.zeros(*self.flame_params['neck_pose'][:1].shape),
+ torch.tensor([[0.3, 0, 0]]),
+ torch.zeros(*self.flame_params['eyes_pose'][:1].shape),
+ torch.zeros(*self.flame_params['translation'][:1].shape),
+ return_verts_cano=False,
+ static_offset=static_offset,
+ )
+ verts = ret[0]
+
+ cano_mesh_path = self.tgt_folder / 'canonical.obj'
+ print(f"Writing canonical mesh to: {cano_mesh_path}")
+ obj_data = get_obj_content(verts[0], self.flame_model.faces)
+ write_data({cano_mesh_path: obj_data})
+
+ @staticmethod
+ def write_expr_and_mesh(tgt_folder, exp_path, expr, mesh_path, verts, faces):
+ path2data = {}
+
+ expr_data = '\n'.join([str(n) for n in expr])
+ path2data[tgt_folder / exp_path] = expr_data
+
+ obj_data = get_obj_content(verts, faces)
+ path2data[tgt_folder / mesh_path] = obj_data
+ write_data(path2data)
+
+ def write_canonical_flame_param(self):
+ flame_param = {
+ 'translation': np.zeros_like(self.flame_params['translation'][:1]),
+ 'rotation': np.zeros_like(self.flame_params['rotation'][:1]),
+ 'neck_pose': np.zeros_like(self.flame_params['neck_pose'][:1]),
+ 'jaw_pose': np.array([[0.3, 0, 0]]), # open mouth
+ 'eyes_pose': np.zeros_like(self.flame_params['eyes_pose'][:1]),
+ 'shape': self.flame_params['shape'],
+ 'expr': np.zeros_like(self.flame_params['expr'][:1]),
+ }
+ if 'static_offset' in self.flame_params:
+ flame_param['static_offset'] = self.flame_params['static_offset']
+
+ cano_flame_param_path = self.tgt_folder / 'canonical_flame_param.npz'
+ print(f"Writing canonical FLAME parameters to: {cano_flame_param_path}")
+ write_data({cano_flame_param_path: flame_param})
+
+ @staticmethod
+ def write_flame_param(tgt_folder, flame_param_path, flame_params, tid):
+ params = {
+ 'translation': flame_params['translation'][[tid]],
+ 'rotation': flame_params['rotation'][[tid]],
+ 'neck_pose': flame_params['neck_pose'][[tid]],
+ 'jaw_pose': flame_params['jaw_pose'][[tid]],
+ 'eyes_pose': flame_params['eyes_pose'][[tid]],
+ 'shape': flame_params['shape'],
+ 'expr': flame_params['expr'][[tid]],
+ }
+
+ if 'static_offset' in flame_params:
+ params['static_offset'] = flame_params['static_offset']
+ if 'dynamic_offset' in flame_params:
+ params['dynamic_offset'] = flame_params['dynamic_offset'][[tid]]
+
+ path2data = {tgt_folder / flame_param_path: params}
+ write_data(path2data)
+
+class MaskFromFLAME:
+ def __init__(self, cfg_model: ModelConfig, tgt_folder, background_color: str) -> None:
+ background_color = self.cfg_data.background_color if background_color is None else background_color
+ if background_color == 'white':
+ self.background_tensor = torch.tensor([255, 255, 255]).byte()
+ elif background_color == 'black':
+ self.background_tensor = torch.tensor([0, 0, 0]).byte()
+ else:
+ raise ValueError(f"Unknown background color: {background_color}")
+
+ dataset = NeRFDataset(
+ root_folder=tgt_folder,
+ division=None,
+ camera_convention_conversion=None,
+ target_extrinsic_type='w2c',
+ use_fg_mask=True,
+ use_flame_param=True,
+ )
+ self.dataloader = DataLoader(dataset, shuffle=False, batch_size=None, collate_fn=None, num_workers=0)
+
+ self.flame_model = FlameHead(cfg_model.n_shape, cfg_model.n_expr, add_teeth=True)
+
+ self.mesh_renderer = NVDiffRenderer(use_opengl=False)
+
+ @torch.no_grad()
+ def write(self):
+ t2verts = {}
+ worker_args = []
+ print(f"Generating masks from FLAME...")
+ for i, frame in enumerate(tqdm(self.dataloader)):
+
+ # get FLAME vertices
+ timestep = frame['timestep_index']
+ if timestep not in t2verts:
+ t2verts[timestep] = infer_flame_params(self.flame_model, frame['flame_param'], [0]).cuda()
+ verts = t2verts[timestep]
+
+ # render to get forground mask
+ RT = frame['extrinsics'].cuda()[None]
+ K = frame['intrinsics'].cuda()[None]
+ h = frame['image_height']
+ w = frame['image_width']
+
+ # mask = self.get_mask(verts, RT, K, h, w)
+ mask = self.get_mask_tilted_line(verts, RT, K, h, w)
+
+ # edit the image and mask with dilated FLAME mask
+ img = frame['image'].cuda()
+ img = img * mask[:, :, None] + self.background_tensor.cuda()[None, None, :] * (1-mask)[:, :, None]
+
+ # overwrite the original images
+ path2data = {
+ str(frame['image_path']): img.byte().cpu().numpy(),
+ }
+
+ if 'fg_mask_path' in frame and 'fg_mask' in frame:
+ fg_mask = frame['fg_mask'].cuda()
+ fg_mask = fg_mask * mask
+
+ # overwrite the original masks
+ path2data.update({
+ str(frame['fg_mask_path']): fg_mask.byte().cpu().numpy(),
+ })
+
+ # # write to new folder
+ # path2data.update({
+ # str(frame['fg_mask_path']).replace('fg_masks', 'fg_masks_'): fg_mask.byte().cpu().numpy(),
+ # })
+
+ write_data(path2data)
+ worker_args.append([path2data])
+
+ #--- no threading
+ # if len(worker_args) > 0:
+ # write_data(path2data)
+
+ #--- threading
+ if len(worker_args) == max_threads or i == len(self.dataloader)-1:
+ with concurrent.futures.ThreadPoolExecutor(max_threads) as executor:
+ futures = [executor.submit(write_data, *args) for args in worker_args]
+ concurrent.futures.wait(futures)
+ worker_args = []
+
+ def get_mask(self, verts, RT, K, h, w):
+ faces = self.flame_model.faces.cuda()
+ out_dict = self.mesh_renderer.render_without_texture(verts, faces, RT, K, (h, w))
+
+ rgba_mesh = out_dict['rgba'].squeeze(0) # (H, W, C)
+ mask_mesh = rgba_mesh[..., 3] # (H, W)
+
+ # get the bottom line of the neck and disable mask for the upper part
+ verts_clip = out_dict['verts_clip'][0]
+ verts_ndc = verts_clip[:, :3] / verts_clip[:, -1:]
+ xy = verts_ndc[:, :2]
+ xy[:, 1] = -xy[:, 1]
+ xy = (xy * 0.5 + 0.5) * torch.tensor([[h, w]]).cuda()
+ vid_ring = self.flame_model.mask.get_vid_by_region(['neck_top'])
+ xy_ring = xy[vid_ring]
+ bottom_line = int(xy_ring[:, 1].min().item())
+
+ mask = mask_mesh.clone()
+ mask[:bottom_line] = 1
+
+ # anti-aliasing with gaussian kernel
+ k = int(0.02 * w)//2 * 2 + 1
+ blur = torchvision.transforms.GaussianBlur(k, sigma=k)
+ mask = blur(mask[None])[0] #.clamp(0, 1)
+ return mask
+
+ def get_mask_tilted_line(self, verts, RT, K, h, w):
+ verts_ndc = self.mesh_renderer.world_to_ndc(verts, RT, K, (h, w), flip_y=True)
+
+ verts_xy = verts_ndc[0, :, :2]
+ verts_xy = (verts_xy * 0.5 + 0.5) * torch.tensor([w, h]).cuda()
+
+ verts_xy_left = verts_xy[self.flame_model.mask.get_vid_by_region(['neck_right_point'])]
+ verts_xy_right = verts_xy[self.flame_model.mask.get_vid_by_region(['neck_left_point'])]
+ verts_xy_bottom = verts_xy[self.flame_model.mask.get_vid_by_region(['front_middle_bottom_point_boundary'])]
+
+ delta_xy = verts_xy_left - verts_xy_right
+ assert (delta_xy[:, 0] != 0).all()
+ k = delta_xy[:, 1] / delta_xy[:, 0]
+ b = verts_xy_bottom[:, 1] - k * verts_xy_bottom[:, 0]
+
+ x = torch.arange(w).cuda()
+ y = torch.arange(h).cuda()
+ yx = torch.stack(torch.meshgrid(y, x, indexing='ij'), dim=-1)
+
+ mask = ((k * yx[:, :, 1] + b - yx[:, :, 0]) > 0).float()
+
+ # anti-aliasing with gaussian kernel
+ k = int(0.03 * w)//2 * 2 + 1
+ blur = torchvision.transforms.GaussianBlur(k, sigma=k)
+ mask = blur(mask[None])[0] #.clamp(0, 1)
+ return mask
+
+def infer_flame_params(flame_model: FlameHead, flame_params: Dict, indices:List):
+ if 'static_offset' in flame_params:
+ static_offset = flame_params['static_offset']
+ if isinstance(static_offset, np.ndarray):
+ static_offset = torch.tensor(static_offset)
+ else:
+ static_offset = None
+ for k in flame_params:
+ if isinstance(flame_params[k], np.ndarray):
+ flame_params[k] = torch.tensor(flame_params[k])
+ with torch.no_grad():
+ ret = flame_model(
+ flame_params['shape'][None, ...].expand(len(indices), -1),
+ flame_params['expr'][indices],
+ flame_params['rotation'][indices],
+ flame_params['neck_pose'][indices],
+ flame_params['jaw_pose'][indices],
+ flame_params['eyes_pose'][indices],
+ flame_params['translation'][indices],
+ return_verts_cano=False,
+ static_offset=static_offset,
+ )
+ verts = ret[0]
+ return verts
+
+
+
+def write_json(db, tgt_folder, division=None):
+ fname = "transforms.json" if division is None else f"transforms_{division}.json"
+ json_path = tgt_folder / fname
+ print(f"Writing database: {json_path}")
+ with open(json_path, "w") as f:
+ json.dump(db, f, indent=4)
+
+def write_data(path2data):
+ for path, data in path2data.items():
+ path = Path(path)
+ if not path.parent.exists():
+ path.parent.mkdir(parents=True, exist_ok=True)
+
+ if path.suffix in [".png", ".jpg"]:
+ Image.fromarray(data).save(path)
+ elif path.suffix in [".obj"]:
+ with open(path, "w") as f:
+ f.write(data)
+ elif path.suffix in [".txt"]:
+ with open(path, "w") as f:
+ f.write(data)
+ elif path.suffix in [".npz"]:
+ np.savez(path, **data)
+ else:
+ raise NotImplementedError(f"Unknown file type: {path.suffix}")
+
+def split_json(tgt_folder: Path, train_ratio=0.7):
+ db = json.load(open(tgt_folder / "transforms.json", "r"))
+
+ # init db for each division
+ db_train = {k: v for k, v in db.items() if k not in ['frames', 'timestep_indices', 'camera_indices']}
+ db_train['frames'] = []
+ db_val = deepcopy(db_train)
+ db_test = deepcopy(db_train)
+
+ # divide timesteps
+ nt = len(db['timestep_indices'])
+ assert 0 < train_ratio <= 1
+ nt_train = int(np.ceil(nt * train_ratio))
+ nt_test = nt - nt_train
+
+ # record number of timesteps
+ timestep_indices = sorted(db['timestep_indices'])
+ db_train['timestep_indices'] = timestep_indices[:nt_train]
+ db_val['timestep_indices'] = timestep_indices[:nt_train] # validation set share the same timesteps with training set
+ db_test['timestep_indices'] = timestep_indices[nt_train:]
+
+ if len(db['camera_indices']) > 1:
+ # when having multiple cameras, leave one camera for validation (novel-view sythesis)
+ if 8 in db['camera_indices']:
+ # use camera 8 for validation (front-view of the NeRSemble dataset)
+ db_train['camera_indices'] = [i for i in db['camera_indices'] if i != 8]
+ db_val['camera_indices'] = [8]
+ db_test['camera_indices'] = db['camera_indices']
+ else:
+ # use the last camera for validation
+ db_train['camera_indices'] = db['camera_indices'][:-1]
+ db_val['camera_indices'] = [db['camera_indices'][-1]]
+ db_test['camera_indices'] = db['camera_indices']
+ else:
+ # when only having one camera, we create an empty validation set
+ db_train['camera_indices'] = db['camera_indices']
+ db_val['camera_indices'] = []
+ db_test['camera_indices'] = db['camera_indices']
+
+ # fill data by timestep index
+ range_train = range(db_train['timestep_indices'][0], db_train['timestep_indices'][-1]+1) if nt_train > 0 else []
+ range_test = range(db_test['timestep_indices'][0], db_test['timestep_indices'][-1]+1) if nt_test > 0 else []
+ for f in db['frames']:
+ if f['timestep_index'] in range_train:
+ if f['camera_index'] in db_train['camera_indices']:
+ db_train['frames'].append(f)
+ elif f['camera_index'] in db_val['camera_indices']:
+ db_val['frames'].append(f)
+ else:
+ raise ValueError(f"Unknown camera index: {f['camera_index']}")
+ elif f['timestep_index'] in range_test:
+ db_test['frames'].append(f)
+ assert f['camera_index'] in db_test['camera_indices'], f"Unknown camera index: {f['camera_index']}"
+ else:
+ raise ValueError(f"Unknown timestep index: {f['timestep_index']}")
+
+ write_json(db_train, tgt_folder, division='train')
+ write_json(db_val, tgt_folder, division='val')
+ write_json(db_test, tgt_folder, division='test')
+
+def load_config(src_folder: Path):
+ config_path = src_folder / "config.yml"
+ if not config_path.exists():
+ src_folder = sorted(src_folder.iterdir())[-1]
+ config_path = src_folder / "config.yml"
+ assert config_path.exists(), f"File not found: {config_path}"
+
+ cfg = yaml.load(config_path.read_text(), Loader=yaml.Loader)
+ # assert isinstance(cfg, BaseTrackingConfig)
+ return src_folder, cfg
+
+def check_epoch(src_folder: Path, epoch: int):
+ paths = [Path(p) for p in glob(str(src_folder / "tracked_flame_params*.npz"))]
+ epochs = [int(p.stem.split('_')[-1]) for p in paths]
+ if epoch == -1:
+ index = np.argmax(epochs)
+ else:
+ try:
+ index = epochs.index(epoch)
+ except ValueError:
+ raise ValueError(f"Could not find epoch {epoch} in {src_folder}")
+
+def main(
+ src_folder: Path,
+ tgt_folder: Path,
+ subset: Optional[str]=None,
+ scale_factor: Optional[float]=None,
+ background_color: Optional[str]=None,
+ flame_mode: Literal['mesh', 'param']='param',
+ create_mask_from_mesh: bool=False,
+ epoch: int=-1,
+ ):
+ print(f"Begin exportation from {src_folder}")
+ assert src_folder.exists(), f"Folder not found: {src_folder}"
+ src_folder, cfg = load_config(src_folder)
+
+ check_epoch(src_folder, epoch)
+
+ if epoch != -1:
+ tgt_folder = Path(str(tgt_folder) + f"_epoch{epoch}")
+
+ nerf_dataset_writer = NeRFDatasetWriter(cfg.data, tgt_folder, subset, scale_factor, background_color)
+ nerf_dataset_writer.write()
+
+ flame_dataset_writer = TrackedFLAMEDatasetWriter(cfg.model, src_folder, tgt_folder, mode=flame_mode, epoch=epoch)
+ flame_dataset_writer.write()
+
+ if create_mask_from_mesh:
+ mask_generator = MaskFromFLAME(cfg.model, tgt_folder, background_color)
+ mask_generator.write()
+
+ split_json(tgt_folder)
+
+ print("Finshed!")
+
+
+if __name__ == "__main__":
+ tyro.cli(main)
\ No newline at end of file
diff --git a/vhap/flame_editor.py b/vhap/flame_editor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bdf5ecdbc90bf4937ca6c66f7916726c27cfe76
--- /dev/null
+++ b/vhap/flame_editor.py
@@ -0,0 +1,362 @@
+import tyro
+from dataclasses import dataclass
+from typing import Optional
+from pathlib import Path
+import time
+import dearpygui.dearpygui as dpg
+import numpy as np
+import torch
+
+from vhap.util.camera import OrbitCamera
+from vhap.model.flame import FlameHead
+from vhap.config.base import ModelConfig
+from vhap.util.render_nvdiffrast import NVDiffRenderer
+
+
+@dataclass
+class Config:
+ model: ModelConfig
+ """FLAME model configuration"""
+ param_path: Optional[Path] = None
+ """Path to the npz file for FLAME parameters"""
+ W: int = 1024
+ """GUI width"""
+ H: int = 1024
+ """GUI height"""
+ radius: float = 1
+ """default GUI camera radius from center"""
+ fovy: float = 30
+ """default GUI camera fovy"""
+ background_color: tuple[float] = (1., 1., 1.)
+ """default GUI background color"""
+ use_opengl: bool = False
+ """use OpenGL or CUDA rasterizer"""
+
+
+class FlameViewer:
+ def __init__(self, cfg: Config):
+ self.cfg = cfg # shared with the trainer's cfg to support in-place modification of rendering parameters.
+
+ # flame model
+ self.flame_model = FlameHead(
+ cfg.model.n_shape,
+ cfg.model.n_expr,
+ add_teeth=True,
+ include_lbs_color=True,
+ )
+ self.reset_flame_param()
+
+ # viewer settings
+ self.W = cfg.W
+ self.H = cfg.H
+ self.cam = OrbitCamera(self.W, self.H, r=cfg.radius, fovy=cfg.fovy, convention="opengl")
+ self.last_time_fresh = None
+ self.render_mode = '-'
+ self.selected_regions = '-'
+ self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32)
+ self.need_update = True # camera moved, should reset accumulation
+
+ # buffers for mouse interaction
+ self.cursor_x = None
+ self.cursor_y = None
+ self.drag_begin_x = None
+ self.drag_begin_y = None
+ self.drag_button = None
+
+ # rendering settings
+ self.mesh_renderer = NVDiffRenderer(use_opengl=cfg.use_opengl, lighting_space='camera')
+
+ self.define_gui()
+
+ def __del__(self):
+ dpg.destroy_context()
+
+ def refresh(self):
+ dpg.set_value("_texture", self.render_buffer)
+
+ if self.last_time_fresh is not None:
+ elapsed = time.time() - self.last_time_fresh
+ fps = 1 / elapsed
+ dpg.set_value("_log_fps", f'{fps:.1f}')
+ self.last_time_fresh = time.time()
+
+ def define_gui(self):
+ dpg.create_context()
+
+ # register texture =================================================================================================
+ with dpg.texture_registry(show=False):
+ dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
+
+ # register window ==================================================================================================
+ # the window to display the rendered image
+ with dpg.window(label="viewer", tag="_render_window", width=self.W, height=self.H, no_title_bar=True, no_move=True, no_bring_to_front_on_focus=True, no_resize=True):
+ dpg.add_image("_texture", width=self.W, height=self.H, tag="_image")
+
+ # control window ==================================================================================================
+ with dpg.window(label="Control", tag="_control_window", autosize=True):
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("FPS: ")
+ dpg.add_text("", tag="_log_fps")
+
+ # rendering options
+ with dpg.collapsing_header(label="Render", default_open=True):
+
+ def callback_set_render_mode(sender, app_data):
+ self.render_mode = app_data
+ self.need_update = True
+ dpg.add_combo(('-', 'lbs weights'), label='render mode', default_value=self.render_mode, tag="_combo_render_mode", callback=callback_set_render_mode)
+
+ def callback_select_regions(sender, app_data):
+ self.selected_regions = app_data
+ self.need_update = True
+ dpg.add_combo(['-']+sorted(self.flame_model.mask.v.keys()), label='regions', default_value='-', tag="_combo_regions", callback=callback_select_regions)
+
+ # fov slider
+ def callback_set_fovy(sender, app_data):
+ self.cam.fovy = app_data
+ self.need_update = True
+ dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy, tag="_slider_fovy")
+
+ def callback_reset_camera(sender, app_data):
+ self.cam.reset()
+ self.need_update = True
+ dpg.set_value("_slider_fovy", self.cam.fovy)
+
+ with dpg.group(horizontal=True):
+ dpg.add_button(label="reset camera", tag="_button_reset_pose", callback=callback_reset_camera)
+
+
+ # FLAME paraemter options
+ with dpg.collapsing_header(label="Parameters", default_open=True):
+
+ def callback_set_pose(sender, app_data):
+ joint, axis = sender.split('-')[1:3]
+ axis_idx = {'x': 0, 'y': 1, 'z': 2}[axis]
+ self.flame_param[joint][0, axis_idx] = app_data
+ self.need_update = True
+ self.pose_sliders = []
+ slider_width = 87
+ for joint in ['neck', 'jaw']:
+ dpg.add_text(f'{joint:9s}')
+ if joint in self.flame_param:
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="x", min_value=-1, max_value=1, format="%.2f", default_value=self.flame_param[joint][0, 0], callback=callback_set_pose, tag=f"_slider-{joint}-x", width=slider_width)
+ dpg.add_slider_float(label="y", min_value=-1, max_value=1, format="%.2f", default_value=self.flame_param[joint][0, 1], callback=callback_set_pose, tag=f"_slider-{joint}-y", width=slider_width)
+ dpg.add_slider_float(label="z", min_value=-1, max_value=1, format="%.2f", default_value=self.flame_param[joint][0, 2], callback=callback_set_pose, tag=f"_slider-{joint}-z", width=slider_width)
+ self.pose_sliders.append(f"_slider-{joint}-x")
+ self.pose_sliders.append(f"_slider-{joint}-y")
+ self.pose_sliders.append(f"_slider-{joint}-z")
+
+ def callback_set_expr(sender, app_data):
+ expr_i = int(sender.split('-')[2])
+ self.flame_param['expr'][0, expr_i] = app_data
+ self.need_update = True
+ self.expr_sliders = []
+ dpg.add_text(f'expr')
+ for i in range(5):
+ dpg.add_slider_float(label=f"{i}", min_value=-5, max_value=5, format="%.2f", default_value=0, callback=callback_set_expr, tag=f"_slider-expr-{i}", width=300)
+ self.expr_sliders.append(f"_slider-expr-{i}")
+
+ def callback_reset_flame(sender, app_data):
+ self.reset_flame_param()
+ self.need_update = True
+ for slider in self.pose_sliders + self.expr_sliders:
+ dpg.set_value(slider, 0)
+ dpg.add_button(label="reset FLAME", tag="_button_reset_flame", callback=callback_reset_flame)
+
+ ### register mouse handlers ========================================================================================
+
+ def callback_mouse_move(sender, app_data):
+ self.cursor_x, self.cursor_y = app_data
+ if not dpg.is_item_focused("_render_window"):
+ return
+
+ if self.drag_begin_x is None or self.drag_begin_y is None:
+ self.drag_begin_x = self.cursor_x
+ self.drag_begin_y = self.cursor_y
+ else:
+ dx = self.cursor_x - self.drag_begin_x
+ dy = self.cursor_y - self.drag_begin_y
+
+ # button=dpg.mvMouseButton_Left
+ if self.drag_button is dpg.mvMouseButton_Left:
+ self.cam.orbit(dx, dy)
+ self.need_update = True
+ elif self.drag_button is dpg.mvMouseButton_Middle:
+ self.cam.pan(dx, dy)
+ self.need_update = True
+
+ def callback_mouse_button_down(sender, app_data):
+ if not dpg.is_item_focused("_render_window"):
+ return
+ self.drag_begin_x = self.cursor_x
+ self.drag_begin_y = self.cursor_y
+ self.drag_button = app_data[0]
+
+ def callback_mouse_release(sender, app_data):
+ self.drag_begin_x = None
+ self.drag_begin_y = None
+ self.drag_button = None
+
+ self.dx_prev = None
+ self.dy_prev = None
+
+ def callback_mouse_drag(sender, app_data):
+ if not dpg.is_item_focused("_render_window"):
+ return
+
+ button, dx, dy = app_data
+ if self.dx_prev is None or self.dy_prev is None:
+ ddx = dx
+ ddy = dy
+ else:
+ ddx = dx - self.dx_prev
+ ddy = dy - self.dy_prev
+
+ self.dx_prev = dx
+ self.dy_prev = dy
+
+ if ddx != 0 and ddy != 0:
+ if button is dpg.mvMouseButton_Left:
+ self.cam.orbit(ddx, ddy)
+ self.need_update = True
+ elif button is dpg.mvMouseButton_Middle:
+ self.cam.pan(ddx, ddy)
+ self.need_update = True
+
+ def callback_camera_wheel_scale(sender, app_data):
+ if not dpg.is_item_focused("_render_window"):
+ return
+ delta = app_data
+ self.cam.scale(delta)
+ self.need_update = True
+
+ with dpg.handler_registry():
+ # this registry order helps avoid false fire
+ dpg.add_mouse_release_handler(callback=callback_mouse_release)
+ # dpg.add_mouse_drag_handler(callback=callback_mouse_drag) # not using the drag callback, since it does not return the starting point
+ dpg.add_mouse_move_handler(callback=callback_mouse_move)
+ dpg.add_mouse_down_handler(callback=callback_mouse_button_down)
+ dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
+
+ # key press handlers
+ # dpg.add_key_press_handler(dpg.mvKey_Left, callback=callback_set_current_frame, tag='_mvKey_Left')
+ # dpg.add_key_press_handler(dpg.mvKey_Right, callback=callback_set_current_frame, tag='_mvKey_Right')
+ # dpg.add_key_press_handler(dpg.mvKey_Home, callback=callback_set_current_frame, tag='_mvKey_Home')
+ # dpg.add_key_press_handler(dpg.mvKey_End, callback=callback_set_current_frame, tag='_mvKey_End')
+
+ def callback_viewport_resize(sender, app_data):
+ while self.rendering:
+ time.sleep(0.01)
+ self.need_update = False
+ self.W = app_data[0]
+ self.H = app_data[1]
+ self.cam.image_width = self.W
+ self.cam.image_height = self.H
+ self.render_buffer = np.zeros((self.H, self.W, 3), dtype=np.float32)
+
+ # delete and re-add the texture and image
+ dpg.delete_item("_texture")
+ dpg.delete_item("_image")
+
+ with dpg.texture_registry(show=False):
+ dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
+ dpg.add_image("_texture", width=self.W, height=self.H, tag="_image", parent="_render_window")
+ dpg.configure_item("_render_window", width=self.W, height=self.H)
+ self.need_update = True
+ dpg.set_viewport_resize_callback(callback_viewport_resize)
+
+ ### global theme ==================================================================================================
+ with dpg.theme() as theme_no_padding:
+ with dpg.theme_component(dpg.mvAll):
+ # set all padding to 0 to avoid scroll bar
+ dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.bind_item_theme("_render_window", theme_no_padding)
+
+ ### finish setup ==================================================================================================
+ dpg.create_viewport(title='FLAME Editor', width=self.W, height=self.H, resizable=True)
+ dpg.setup_dearpygui()
+ dpg.show_viewport()
+
+ def reset_flame_param(self):
+ self.flame_param = {
+ 'shape': torch.zeros(1, self.cfg.model.n_shape),
+ 'expr': torch.zeros(1, self.cfg.model.n_expr),
+ 'rotation': torch.zeros(1, 3),
+ 'neck': torch.zeros(1, 3),
+ 'jaw': torch.zeros(1, 3),
+ 'eyes': torch.zeros(1, 6),
+ 'translation': torch.zeros(1, 3),
+ 'static_offset': torch.zeros(1, 3),
+ 'dynamic_offset': torch.zeros(1, 3),
+ }
+
+ def forward_flame(self, flame_param):
+ N = flame_param['expr'].shape[0]
+
+ self.verts, self.verts_cano = self.flame_model(
+ **flame_param,
+ zero_centered_at_root_node=False,
+ return_landmarks=False,
+ return_verts_cano=True,
+ )
+
+ def prepare_camera(self):
+ @dataclass
+ class Cam:
+ FoVx = float(np.radians(self.cam.fovx))
+ FoVy = float(np.radians(self.cam.fovy))
+ image_height = self.cam.image_height
+ image_width = self.cam.image_width
+ world_view_transform = torch.tensor(self.cam.world_view_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer
+ full_proj_transform = torch.tensor(self.cam.full_proj_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer
+ camera_center = torch.tensor(self.cam.pose[:3, 3]).cuda()
+ return Cam
+
+ def run(self):
+
+ while dpg.is_dearpygui_running():
+
+ if self.need_update:
+ self.rendering = True
+
+ with torch.no_grad():
+ # mesh
+ self.forward_flame(self.flame_param)
+ verts = self.verts.cuda()
+ faces = self.flame_model.faces.cuda()
+
+ # camera
+ RT = torch.from_numpy(self.cam.world_view_transform).cuda()[None]
+ K = torch.from_numpy(self.cam.intrinsics).cuda()[None]
+ image_size = self.cam.image_height, self.cam.image_width
+
+ if self.render_mode == 'lbs weights':
+ v_color = self.flame_model.lbs_color.cuda()
+ else:
+ v_color = torch.ones_like(verts)
+
+ if self.selected_regions != '-':
+ vid = self.flame_model.mask.get_vid_except_region(self.selected_regions)
+ v_color[..., vid, :] *= 0.3
+
+ out_dict = self.mesh_renderer.render_v_color(verts, v_color, faces, RT, K, image_size, self.cfg.background_color)
+
+ rgba_mesh = out_dict['rgba'].squeeze(0).permute(2, 0, 1) # (C, W, H)
+ rgb_mesh = rgba_mesh[:3, :, :]
+
+ self.render_buffer = rgb_mesh.permute(1, 2, 0).cpu().numpy()
+ self.refresh()
+
+ self.rendering = False
+ self.need_update = False
+ dpg.render_dearpygui_frame()
+
+
+if __name__ == "__main__":
+ cfg = tyro.cli(Config)
+ gui = FlameViewer(cfg)
+ gui.run()
diff --git a/vhap/flame_viewer.py b/vhap/flame_viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b514162c225e3fd03fa5d19469bb06ffe3429ca0
--- /dev/null
+++ b/vhap/flame_viewer.py
@@ -0,0 +1,323 @@
+import tyro
+from dataclasses import dataclass
+from typing import Optional
+from pathlib import Path
+import time
+import dearpygui.dearpygui as dpg
+import numpy as np
+import torch
+
+from vhap.util.camera import OrbitCamera
+from vhap.model.flame import FlameHead
+from vhap.config.base import ModelConfig
+from vhap.util.render_nvdiffrast import NVDiffRenderer
+
+
+@dataclass
+class Config:
+ model: ModelConfig
+ """FLAME model configuration"""
+ param_path: Optional[Path] = None
+ """Path to the npz file for FLAME parameters"""
+ W: int = 1024
+ """GUI width"""
+ H: int = 1024
+ """GUI height"""
+ radius: float = 1
+ """default GUI camera radius from center"""
+ fovy: float = 30
+ """default GUI camera fovy"""
+ background_color: tuple[float] = (1., 1., 1.)
+ """default GUI background color"""
+ use_opengl: bool = False
+ """use OpenGL or CUDA rasterizer"""
+
+
+class FlameViewer:
+ def __init__(self, cfg: Config):
+ self.cfg = cfg # shared with the trainer's cfg to support in-place modification of rendering parameters.
+
+ # flame model
+ self.flame_model = FlameHead(cfg.model.n_shape, cfg.model.n_expr, add_teeth=True)
+
+ # viewer settings
+ self.W = cfg.W
+ self.H = cfg.H
+ self.cam = OrbitCamera(self.W, self.H, r=cfg.radius, fovy=cfg.fovy, convention="opengl")
+ self.last_time_fresh = None
+ self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32)
+ self.need_update = True # camera moved, should reset accumulation
+
+ # buffers for mouse interaction
+ self.cursor_x = None
+ self.cursor_y = None
+ self.drag_begin_x = None
+ self.drag_begin_y = None
+ self.drag_button = None
+
+ # rendering settings
+ self.mesh_renderer = NVDiffRenderer(use_opengl=cfg.use_opengl, lighting_space='camera')
+ self.num_timesteps = 1
+ self.timestep = 0
+
+ self.define_gui()
+
+ def __del__(self):
+ dpg.destroy_context()
+
+ def refresh(self):
+ dpg.set_value("_texture", self.render_buffer)
+
+ if self.last_time_fresh is not None:
+ elapsed = time.time() - self.last_time_fresh
+ fps = 1 / elapsed
+ dpg.set_value("_log_fps", f'{fps:.1f}')
+ self.last_time_fresh = time.time()
+
+ def define_gui(self):
+ dpg.create_context()
+
+ # register texture =================================================================================================
+ with dpg.texture_registry(show=False):
+ dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
+
+ # register window ==================================================================================================
+ # the window to display the rendered image
+ with dpg.window(label="viewer", tag="_render_window", width=self.W, height=self.H, no_title_bar=True, no_move=True, no_bring_to_front_on_focus=True, no_resize=True):
+ dpg.add_image("_texture", width=self.W, height=self.H, tag="_image")
+
+ # control window ==================================================================================================
+ with dpg.window(label="Control", tag="_control_window", autosize=True):
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("FPS: ")
+ dpg.add_text("", tag="_log_fps")
+
+ # rendering options
+ with dpg.collapsing_header(label="Render", default_open=True):
+
+ # timestep slider and buttons
+ if self.num_timesteps != None:
+ def callback_set_current_frame(sender, app_data):
+ if sender == "_slider_timestep":
+ self.timestep = app_data
+ elif sender in ["_button_timestep_plus", "_mvKey_Right"]:
+ self.timestep = min(self.timestep + 1, self.num_timesteps - 1)
+ elif sender in ["_button_timestep_minus", "_mvKey_Left"]:
+ self.timestep = max(self.timestep - 1, 0)
+ elif sender == "_mvKey_Home":
+ self.timestep = 0
+ elif sender == "_mvKey_End":
+ self.timestep = self.num_timesteps - 1
+
+ dpg.set_value("_slider_timestep", self.timestep)
+
+ self.need_update = True
+ with dpg.group(horizontal=True):
+ dpg.add_button(label='-', tag="_button_timestep_minus", callback=callback_set_current_frame)
+ dpg.add_button(label='+', tag="_button_timestep_plus", callback=callback_set_current_frame)
+ dpg.add_slider_int(label="timestep", tag='_slider_timestep', width=162, min_value=0, max_value=self.num_timesteps - 1, format="%d", default_value=0, callback=callback_set_current_frame)
+
+ # fov slider
+ def callback_set_fovy(sender, app_data):
+ self.cam.fovy = app_data
+ self.need_update = True
+ dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy, tag="_slider_fovy")
+
+ def callback_reset_camera(sender, app_data):
+ self.cam.reset()
+ self.need_update = True
+ dpg.set_value("_slider_fovy", self.cam.fovy)
+
+ with dpg.group(horizontal=True):
+ dpg.add_button(label="reset camera", tag="_button_reset_pose", callback=callback_reset_camera)
+
+
+ ### register mouse handlers ========================================================================================
+
+ def callback_mouse_move(sender, app_data):
+ self.cursor_x, self.cursor_y = app_data
+ if not dpg.is_item_focused("_render_window"):
+ return
+
+ if self.drag_begin_x is None or self.drag_begin_y is None:
+ self.drag_begin_x = self.cursor_x
+ self.drag_begin_y = self.cursor_y
+ else:
+ dx = self.cursor_x - self.drag_begin_x
+ dy = self.cursor_y - self.drag_begin_y
+
+ # button=dpg.mvMouseButton_Left
+ if self.drag_button is dpg.mvMouseButton_Left:
+ self.cam.orbit(dx, dy)
+ self.need_update = True
+ elif self.drag_button is dpg.mvMouseButton_Middle:
+ self.cam.pan(dx, dy)
+ self.need_update = True
+
+ def callback_mouse_button_down(sender, app_data):
+ if not dpg.is_item_focused("_render_window"):
+ return
+ self.drag_begin_x = self.cursor_x
+ self.drag_begin_y = self.cursor_y
+ self.drag_button = app_data[0]
+
+ def callback_mouse_release(sender, app_data):
+ self.drag_begin_x = None
+ self.drag_begin_y = None
+ self.drag_button = None
+
+ self.dx_prev = None
+ self.dy_prev = None
+
+ def callback_mouse_drag(sender, app_data):
+ if not dpg.is_item_focused("_render_window"):
+ return
+
+ button, dx, dy = app_data
+ if self.dx_prev is None or self.dy_prev is None:
+ ddx = dx
+ ddy = dy
+ else:
+ ddx = dx - self.dx_prev
+ ddy = dy - self.dy_prev
+
+ self.dx_prev = dx
+ self.dy_prev = dy
+
+ if ddx != 0 and ddy != 0:
+ if button is dpg.mvMouseButton_Left:
+ self.cam.orbit(ddx, ddy)
+ self.need_update = True
+ elif button is dpg.mvMouseButton_Middle:
+ self.cam.pan(ddx, ddy)
+ self.need_update = True
+
+ def callback_camera_wheel_scale(sender, app_data):
+ if not dpg.is_item_focused("_render_window"):
+ return
+ delta = app_data
+ self.cam.scale(delta)
+ self.need_update = True
+
+ with dpg.handler_registry():
+ # this registry order helps avoid false fire
+ dpg.add_mouse_release_handler(callback=callback_mouse_release)
+ # dpg.add_mouse_drag_handler(callback=callback_mouse_drag) # not using the drag callback, since it does not return the starting point
+ dpg.add_mouse_move_handler(callback=callback_mouse_move)
+ dpg.add_mouse_down_handler(callback=callback_mouse_button_down)
+ dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
+
+ # key press handlers
+ dpg.add_key_press_handler(dpg.mvKey_Left, callback=callback_set_current_frame, tag='_mvKey_Left')
+ dpg.add_key_press_handler(dpg.mvKey_Right, callback=callback_set_current_frame, tag='_mvKey_Right')
+ dpg.add_key_press_handler(dpg.mvKey_Home, callback=callback_set_current_frame, tag='_mvKey_Home')
+ dpg.add_key_press_handler(dpg.mvKey_End, callback=callback_set_current_frame, tag='_mvKey_End')
+
+ def callback_viewport_resize(sender, app_data):
+ while self.rendering:
+ time.sleep(0.01)
+ self.need_update = False
+ self.W = app_data[0]
+ self.H = app_data[1]
+ self.cam.image_width = self.W
+ self.cam.image_height = self.H
+ self.render_buffer = np.zeros((self.H, self.W, 3), dtype=np.float32)
+
+ # delete and re-add the texture and image
+ dpg.delete_item("_texture")
+ dpg.delete_item("_image")
+
+ with dpg.texture_registry(show=False):
+ dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
+ dpg.add_image("_texture", width=self.W, height=self.H, tag="_image", parent="_render_window")
+ dpg.configure_item("_render_window", width=self.W, height=self.H)
+ self.need_update = True
+ dpg.set_viewport_resize_callback(callback_viewport_resize)
+
+ ### global theme ==================================================================================================
+ with dpg.theme() as theme_no_padding:
+ with dpg.theme_component(dpg.mvAll):
+ # set all padding to 0 to avoid scroll bar
+ dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.bind_item_theme("_render_window", theme_no_padding)
+
+ ### finish setup ==================================================================================================
+ dpg.create_viewport(title='FLAME Sequence Viewer', width=self.W, height=self.H, resizable=True)
+ dpg.setup_dearpygui()
+ dpg.show_viewport()
+
+ def forward_flame(self, flame_param):
+ N = flame_param['expr'].shape[0]
+
+ self.verts, self.verts_cano = self.flame_model(
+ flame_param['shape'][None, ...].expand(N, -1),
+ flame_param['expr'],
+ flame_param['rotation'],
+ flame_param['neck_pose'],
+ flame_param['jaw_pose'],
+ flame_param['eyes_pose'],
+ flame_param['translation'],
+ zero_centered_at_root_node=False,
+ return_landmarks=False,
+ return_verts_cano=True,
+ static_offset=flame_param['static_offset'],
+ # dynamic_offset=flame_param['dynamic_offset'],
+ )
+
+ self.num_timesteps = N
+ dpg.configure_item("_slider_timestep", max_value=self.num_timesteps - 1)
+
+ def prepare_camera(self):
+ @dataclass
+ class Cam:
+ FoVx = float(np.radians(self.cam.fovx))
+ FoVy = float(np.radians(self.cam.fovy))
+ image_height = self.cam.image_height
+ image_width = self.cam.image_width
+ world_view_transform = torch.tensor(self.cam.world_view_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer
+ full_proj_transform = torch.tensor(self.cam.full_proj_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer
+ camera_center = torch.tensor(self.cam.pose[:3, 3]).cuda()
+ return Cam
+
+ def run(self):
+ if self.cfg.param_path is not None:
+ if self.cfg.param_path.exists():
+ self.flame_param = dict(np.load(self.cfg.param_path))
+ for k, v in self.flame_param.items():
+ if v.dtype in [np.float64, np.float32]:
+ self.flame_param[k] = torch.from_numpy(v).float()
+ self.forward_flame(self.flame_param)
+ else:
+ raise FileNotFoundError(f'{self.cfg.param_path} does not exist.')
+
+ while dpg.is_dearpygui_running():
+
+ if self.need_update:
+ self.rendering = True
+
+ with torch.no_grad():
+ RT = torch.from_numpy(self.cam.world_view_transform).cuda()[None]
+ K = torch.from_numpy(self.cam.intrinsics).cuda()[None]
+ image_size = self.cam.image_height, self.cam.image_width
+ verts = self.verts[[self.timestep]].cuda()
+ faces = self.flame_model.faces.cuda()
+ out_dict = self.mesh_renderer.render_without_texture(verts, faces, RT, K, image_size, self.cfg.background_color)
+
+ rgba_mesh = out_dict['rgba'].squeeze(0).permute(2, 0, 1) # (C, W, H)
+ rgb_mesh = rgba_mesh[:3, :, :]
+
+ self.render_buffer = rgb_mesh.permute(1, 2, 0).cpu().numpy()
+ self.refresh()
+
+ self.rendering = False
+ self.need_update = False
+ dpg.render_dearpygui_frame()
+
+
+if __name__ == "__main__":
+ cfg = tyro.cli(Config)
+ gui = FlameViewer(cfg)
+ gui.run()
diff --git a/vhap/generate_flame_uvmask.py b/vhap/generate_flame_uvmask.py
new file mode 100644
index 0000000000000000000000000000000000000000..69d6e7d25607a2e575fb7f61ff25966e0b68a4ba
--- /dev/null
+++ b/vhap/generate_flame_uvmask.py
@@ -0,0 +1,81 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+from typing import Literal
+import tyro
+import numpy as np
+from PIL import Image
+from pathlib import Path
+import torch
+import nvdiffrast.torch as dr
+from vhap.util.render_uvmap import render_uvmap_vtex
+from vhap.model.flame import FlameHead
+
+
+FLAME_UV_MASK_FOLDER = "asset/flame/uv_masks"
+FLAME_UV_MASK_NPZ = "asset/flame/uv_masks.npz"
+
+
+def main(
+ use_opengl: bool = False,
+ device: Literal['cuda', 'cpu'] = 'cuda',
+):
+ n_shape = 300
+ n_expr = 100
+ print("Initializing FLAME model")
+ flame_model = FlameHead(n_shape, n_expr, add_teeth=True)
+
+ flame_model = FlameHead(
+ n_shape,
+ n_expr,
+ add_teeth=True,
+ ).cuda()
+
+ faces = flame_model.faces.int().cuda()
+ verts_uv = flame_model.verts_uvs.cuda()
+ # verts_uv[:, 1] = 1 - verts_uv[:, 1]
+ faces_uv = flame_model.textures_idx.int().cuda()
+ col_idx = faces_uv
+
+ # Rasterizer context
+ glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()
+
+ h, w = 2048, 2048
+ resolution = (h, w)
+
+ if not Path(FLAME_UV_MASK_FOLDER).exists():
+ Path(FLAME_UV_MASK_FOLDER).mkdir(parents=True)
+
+ # alpha_maps = {}
+ masks = {}
+ for region, vt_mask in flame_model.mask.vt:
+ v_color = torch.zeros(verts_uv.shape[0], 1).to(device) # alpha channel
+ v_color[vt_mask] = 1
+
+ alpha = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)[0]
+ alpha = alpha.flip(0)
+ # alpha_maps[region] = alpha.cpu().numpy()
+ mask = (alpha > 0.5) # to avoid overlap between hair and face
+ mask = mask.squeeze(-1).cpu().numpy()
+ masks[region] = mask # (h, w)
+
+ print(f"Saving uv mask for {region}...")
+ # rgba = mask.expand(-1, -1, 4) # (h, w, 4)
+ # rgb = torch.ones_like(mask).expand(-1, -1, 3) # (h, w, 3)
+ # rgba = torch.cat([rgb, mask], dim=-1).cpu().numpy() # (h, w, 4)
+ img = mask
+ img = Image.fromarray((img * 255).astype(np.uint8))
+ img.save(Path(FLAME_UV_MASK_FOLDER) / f"{region}.png")
+
+ print(f"Saving uv mask into: {FLAME_UV_MASK_NPZ}")
+ np.savez_compressed(FLAME_UV_MASK_NPZ, **masks)
+
+
+if __name__ == "__main__":
+ tyro.cli(main)
\ No newline at end of file
diff --git a/vhap/model/flame.py b/vhap/model/flame.py
new file mode 100644
index 0000000000000000000000000000000000000000..1328bf7b1f173fa73f1e76f3350acfdf7e18ee25
--- /dev/null
+++ b/vhap/model/flame.py
@@ -0,0 +1,1070 @@
+# Code heavily inspired by https://github.com/HavenFeng/photometric_optimization/blob/master/models/FLAME.py.
+# Please consider citing their work if you find this code useful. The code is subject to the license available via
+# https://github.com/vchoutas/smplx/edit/master/LICENSE
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+
+from vhap.model.lbs import lbs, vertices2landmarks, blend_shapes, vertices2joints
+from vhap.util.mesh import face_vertices
+from vhap.util.log import get_logger
+from pytorch3d.io import load_obj
+from pytorch3d.structures.meshes import Meshes
+from matplotlib import cm
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pickle
+import torch.nn.functional as F
+from collections import defaultdict
+from PIL import Image
+
+logger = get_logger(__name__)
+
+# FLAME_MODEL_PATH = "asset/flame/generic_model.pkl"
+FLAME_MODEL_PATH = "pretrained_models/human_model_files/flame_vhap/flame2023.pkl"
+FLAME_MESH_PATH = "pretrained_models/human_model_files/flame_vhap/head_template_mesh.obj"
+FLAME_PARTS_PATH = "pretrained_models/human_model_files/flame_vhap/FLAME_masks.pkl"
+FLAME_LMK_PATH = "pretrained_models/human_model_files/flame_vhap/landmark_embedding_with_eyes.npy"
+FLAME_TEX_PATH = "pretrained_models/human_model_files/flame_vhap/FLAME_texture.npz"
+FLAME_PAINTED_TEX_PATH = "pretrained_models/human_model_files/flame_vhap/tex_mean_painted.png"
+FLAME_UVMASK_PATH = "pretrained_models/human_model_files/flame_vhap/uv_masks.npz"
+
+
+def to_tensor(array, dtype=torch.float32):
+ if "torch.tensor" not in str(type(array)):
+ return torch.tensor(array, dtype=dtype)
+
+
+def to_np(array, dtype=np.float32):
+ if "scipy.sparse" in str(type(array)):
+ array = array.todense()
+ return np.array(array, dtype=dtype)
+
+
+class Struct(object):
+ def __init__(self, **kwargs):
+ for key, val in kwargs.items():
+ setattr(self, key, val)
+
+
+class FlameHead(nn.Module):
+ """
+ Given flame parameters this class generates a differentiable FLAME function
+ which outputs the a mesh and 2D/3D facial landmarks
+ """
+
+ def __init__(
+ self,
+ shape_params,
+ expr_params,
+ flame_model_path=FLAME_MODEL_PATH,
+ flame_lmk_embedding_path=FLAME_LMK_PATH,
+ flame_template_mesh_path=FLAME_MESH_PATH,
+ include_mask=True,
+ include_lbs_color=False,
+ add_teeth=False,
+ connect_lip_inside=False,
+ remove_lip_inside=False,
+ disable_deformation_on_torso=False,
+ remove_torso=False,
+ face_clusters=[],
+ ):
+ super().__init__()
+
+ logger.info("Initializing FLAME mesh model...")
+
+ self.n_shape_params = shape_params
+ self.n_expr_params = expr_params
+
+ with open(flame_model_path, "rb") as f:
+ ss = pickle.load(f, encoding="latin1")
+ flame_model = Struct(**ss)
+
+ self.dtype = torch.float32
+ # The vertices of the template model
+ self.register_buffer(
+ "v_template", to_tensor(to_np(flame_model.v_template), dtype=self.dtype)
+ )
+
+ # The shape components and expression
+ shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
+ shapedirs = torch.cat(
+ [shapedirs[:, :, :shape_params], shapedirs[:, :, 300 : 300 + expr_params]],
+ 2,
+ )
+ self.register_buffer("shapedirs", shapedirs)
+
+ # The pose components
+ num_pose_basis = flame_model.posedirs.shape[-1]
+ posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
+ self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
+ #
+ self.register_buffer(
+ "J_regressor", to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype)
+ )
+ parents = to_tensor(to_np(flame_model.kintree_table[0])).long()
+ parents[0] = -1
+ self.register_buffer("parents", parents)
+ self.register_buffer(
+ "lbs_weights", to_tensor(to_np(flame_model.weights), dtype=self.dtype)
+ )
+
+ # Landmark embeddings for FLAME
+ lmk_embeddings = np.load(
+ flame_lmk_embedding_path, allow_pickle=True, encoding="latin1"
+ )
+ lmk_embeddings = lmk_embeddings[()]
+ self.register_buffer(
+ "full_lmk_faces_idx",
+ torch.tensor(lmk_embeddings["full_lmk_faces_idx"], dtype=torch.long),
+ )
+ self.register_buffer(
+ "full_lmk_bary_coords",
+ torch.tensor(lmk_embeddings["full_lmk_bary_coords"], dtype=self.dtype),
+ )
+
+ neck_kin_chain = []
+ NECK_IDX = 1
+ curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
+ while curr_idx != -1:
+ neck_kin_chain.append(curr_idx)
+ curr_idx = self.parents[curr_idx]
+ self.register_buffer("neck_kin_chain", torch.stack(neck_kin_chain))
+
+ # add faces and uvs
+ verts, faces, aux = load_obj(flame_template_mesh_path, load_textures=False)
+
+ vertex_uvs = aux.verts_uvs
+ face_uvs_idx = faces.textures_idx # index into verts_uvs
+
+ # create uvcoords per face --> this is what you can use for uv map rendering
+ # range from -1 to 1 (-1, -1) = left top; (+1, +1) = right bottom
+ # pad 1 to the end
+ pad = torch.ones(vertex_uvs.shape[0], 1)
+ vertex_uvs = torch.cat([vertex_uvs, pad], dim=-1)
+ vertex_uvs = vertex_uvs * 2 - 1
+ vertex_uvs[..., 1] = -vertex_uvs[..., 1]
+
+ face_uv_coords = face_vertices(vertex_uvs[None], face_uvs_idx[None])[0]
+ self.register_buffer("face_uvcoords", face_uv_coords, persistent=False)
+ self.register_buffer("faces", faces.verts_idx, persistent=False)
+
+ self.register_buffer("verts_uvs", aux.verts_uvs, persistent=False)
+ self.register_buffer("textures_idx", faces.textures_idx, persistent=False)
+
+ if include_mask:
+ self.mask = FlameMask(
+ faces=self.faces,
+ faces_t=self.textures_idx,
+ num_verts=self.v_template.shape[0],
+ num_faces=self.faces.shape[0],
+ face_clusters=face_clusters,
+ )
+
+ if add_teeth:
+ self.add_teeth()
+
+ if connect_lip_inside:
+ self.connect_lip_inside()
+
+ if remove_lip_inside:
+ # this will change faces indices, so landmarks will be wrong if landmark embeddings are not updated
+ self.remove_lip_inside()
+
+ if remove_torso:
+ # this will change faces indices, so landmarks will be wrong if landmark embeddings are not updated
+ self.remove_torso()
+
+ if disable_deformation_on_torso:
+ self.disable_deformation_on_torso(expr_params)
+
+ # laplacian matrix
+ laplacian_matrix = Meshes(verts=[self.v_template], faces=[faces.verts_idx]).laplacian_packed().to_dense()
+ self.register_buffer("laplacian_matrix", laplacian_matrix, persistent=False)
+
+ D = torch.diag(laplacian_matrix)
+ laplacian_matrix_negate_diag = laplacian_matrix - torch.diag(D) * 2
+ self.register_buffer("laplacian_matrix_negate_diag", laplacian_matrix_negate_diag, persistent=False)
+
+ if include_lbs_color:
+ self.add_lbs_color()
+
+ def add_teeth(self):
+ # get reference vertices from lips
+ vid_lip_outside_ring_upper = self.mask.get_vid_by_region(['lip_outside_ring_upper'], keep_order=True)
+
+ vid_lip_outside_ring_lower = self.mask.get_vid_by_region(['lip_outside_ring_lower'], keep_order=True)
+
+ v_lip_upper = self.v_template[vid_lip_outside_ring_upper]
+ v_lip_lower = self.v_template[vid_lip_outside_ring_lower]
+
+ # construct vertices for teeth
+ mean_dist = (v_lip_upper - v_lip_lower).norm(dim=-1, keepdim=True).mean()
+ v_teeth_middle = (v_lip_upper + v_lip_lower) / 2
+ v_teeth_middle[:, 1] = v_teeth_middle[:, [1]].mean(dim=0, keepdim=True)
+ # v_teeth_middle[:, 2] -= mean_dist * 2.5 # how far the teeth are from the lips
+ # v_teeth_middle[:, 2] -= mean_dist * 2 # how far the teeth are from the lips
+ v_teeth_middle[:, 2] -= mean_dist * 1.5 # how far the teeth are from the lips
+
+ # upper, front
+ v_teeth_upper_edge = v_teeth_middle.clone() + torch.tensor([[0, mean_dist, 0]])*0.1
+ v_teeth_upper_root = v_teeth_upper_edge + torch.tensor([[0, mean_dist, 0]]) * 2 # scale the height of teeth
+
+ # lower, front
+ v_teeth_lower_edge = v_teeth_middle.clone() - torch.tensor([[0, mean_dist, 0]])*0.1
+ # v_teeth_lower_edge -= torch.tensor([[0, 0, mean_dist]]) * 0.2 # slightly move the lower teeth to the back
+ v_teeth_lower_edge -= torch.tensor([[0, 0, mean_dist]]) * 0.4 # slightly move the lower teeth to the back
+ v_teeth_lower_root = v_teeth_lower_edge - torch.tensor([[0, mean_dist, 0]]) * 2 # scale the height of teeth
+
+ # thickness = mean_dist * 0.5
+ thickness = mean_dist * 1.
+ # upper, back
+ v_teeth_upper_root_back = v_teeth_upper_root.clone()
+ v_teeth_upper_edge_back = v_teeth_upper_edge.clone()
+ v_teeth_upper_root_back[:, 2] -= thickness # how thick the teeth are
+ v_teeth_upper_edge_back[:, 2] -= thickness # how thick the teeth are
+
+ # lower, back
+ v_teeth_lower_root_back = v_teeth_lower_root.clone()
+ v_teeth_lower_edge_back = v_teeth_lower_edge.clone()
+ v_teeth_lower_root_back[:, 2] -= thickness # how thick the teeth are
+ v_teeth_lower_edge_back[:, 2] -= thickness # how thick the teeth are
+
+ # concatenate to v_template
+ num_verts_orig = self.v_template.shape[0]
+ v_teeth = torch.cat([
+ v_teeth_upper_root, # num_verts_orig + 0-14
+ v_teeth_lower_root, # num_verts_orig + 15-29
+ v_teeth_upper_edge, # num_verts_orig + 30-44
+ v_teeth_lower_edge, # num_verts_orig + 45-59
+ v_teeth_upper_root_back, # num_verts_orig + 60-74
+ v_teeth_upper_edge_back, # num_verts_orig + 75-89
+ v_teeth_lower_root_back, # num_verts_orig + 90-104
+ v_teeth_lower_edge_back, # num_verts_orig + 105-119
+ ], dim=0)
+ num_verts_teeth = v_teeth.shape[0]
+ self.v_template = torch.cat([self.v_template, v_teeth], dim=0)
+
+ vid_teeth_upper_root = torch.arange(0, 15) + num_verts_orig
+ vid_teeth_lower_root = torch.arange(15, 30) + num_verts_orig
+ vid_teeth_upper_edge = torch.arange(30, 45) + num_verts_orig
+ vid_teeth_lower_edge = torch.arange(45, 60) + num_verts_orig
+ vid_teeth_upper_root_back = torch.arange(60, 75) + num_verts_orig
+ vid_teeth_upper_edge_back = torch.arange(75, 90) + num_verts_orig
+ vid_teeth_lower_root_back = torch.arange(90, 105) + num_verts_orig
+ vid_teeth_lower_edge_back = torch.arange(105, 120) + num_verts_orig
+
+ vid_teeth_upper = torch.cat([vid_teeth_upper_root, vid_teeth_upper_edge, vid_teeth_upper_root_back, vid_teeth_upper_edge_back], dim=0)
+ vid_teeth_lower = torch.cat([vid_teeth_lower_root, vid_teeth_lower_edge, vid_teeth_lower_root_back, vid_teeth_lower_edge_back], dim=0)
+ vid_teeth = torch.cat([vid_teeth_upper, vid_teeth_lower], dim=0)
+
+ # update vertex masks
+ self.mask.v.register_buffer("teeth_upper", vid_teeth_upper)
+ self.mask.v.register_buffer("teeth_lower", vid_teeth_lower)
+ self.mask.v.register_buffer("teeth", vid_teeth)
+ self.mask.v.left_half = torch.cat([
+ self.mask.v.left_half,
+ torch.tensor([
+ 5023, 5024, 5025, 5026, 5027, 5028, 5029, 5030, 5038, 5039, 5040, 5041, 5042, 5043, 5044, 5045, 5053, 5054, 5055, 5056, 5057, 5058, 5059, 5060, 5068, 5069, 5070, 5071, 5072, 5073, 5074, 5075, 5083, 5084, 5085, 5086, 5087, 5088, 5089, 5090, 5098, 5099, 5100, 5101, 5102, 5103, 5104, 5105, 5113, 5114, 5115, 5116, 5117, 5118, 5119, 5120, 5128, 5129, 5130, 5131, 5132, 5133, 5134, 5135,
+ ])], dim=0)
+
+ self.mask.v.right_half = torch.cat([
+ self.mask.v.right_half,
+ torch.tensor([
+ 5030, 5031, 5032, 5033, 5034, 5035, 5036, 5037, 5045, 5046, 5047, 5048, 5049, 5050, 5051, 5052, 5060, 5061, 5062, 5063, 5064, 5065, 5066, 5067, 5075, 5076, 5077, 5078, 5079, 5080, 5081, 5082, 5090, 5091, 5092, 5093, 5094, 5095, 5097, 5105, 5106, 5107, 5108, 5109, 5110, 5111, 5112, 5120, 5121, 5122, 5123, 5124, 5125, 5126, 5127, 5135, 5136, 5137, 5138, 5139, 5140, 5141, 5142,
+ ])], dim=0)
+
+ # construct uv vertices for teeth
+ u = torch.linspace(0.62, 0.38, 15)
+ v = torch.linspace(1-0.0083, 1-0.0425, 7)
+ # v = v[[0, 2, 1, 1]]
+ # v = v[[0, 3, 1, 4, 3, 2, 6, 5]]
+ v = v[[3, 2, 0, 1, 3, 4, 6, 5]] # TODO: with this order, teeth_lower is not rendered correctly in the uv space
+ uv = torch.stack(torch.meshgrid(u, v, indexing='ij'), dim=-1).permute(1, 0, 2).reshape(num_verts_teeth, 2) # (#num_teeth, 2)
+ num_verts_uv_orig = self.verts_uvs.shape[0]
+ num_verts_uv_teeth = uv.shape[0]
+ self.verts_uvs = torch.cat([self.verts_uvs, uv], dim=0)
+
+ # shapedirs copy from lips
+ self.shapedirs = torch.cat([self.shapedirs, torch.zeros_like(self.shapedirs[:num_verts_teeth])], dim=0)
+ shape_dirs_mean = (self.shapedirs[vid_lip_outside_ring_upper, :, :self.n_shape_params] + self.shapedirs[vid_lip_outside_ring_lower, :, :self.n_shape_params]) / 2
+ self.shapedirs[vid_teeth_upper_root, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_root, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_upper_edge, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_edge, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_upper_root_back, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_upper_edge_back, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_root_back, :, :self.n_shape_params] = shape_dirs_mean
+ self.shapedirs[vid_teeth_lower_edge_back, :, :self.n_shape_params] = shape_dirs_mean
+
+ # posedirs set to zero
+ posedirs = self.posedirs.reshape(len(self.parents)-1, 9, num_verts_orig, 3) # (J*9, V*3) -> (J, 9, V, 3)
+ posedirs = torch.cat([posedirs, torch.zeros_like(posedirs[:, :, :num_verts_teeth])], dim=2) # (J, 9, V+num_verts_teeth, 3)
+ self.posedirs = posedirs.reshape((len(self.parents)-1)*9, (num_verts_orig+num_verts_teeth)*3) # (J*9, (V+num_verts_teeth)*3)
+
+ # J_regressor set to zero
+ self.J_regressor = torch.cat([self.J_regressor, torch.zeros_like(self.J_regressor[:, :num_verts_teeth])], dim=1) # (5, J) -> (5, J+num_verts_teeth)
+
+ # lbs_weights manually set
+ self.lbs_weights = torch.cat([self.lbs_weights, torch.zeros_like(self.lbs_weights[:num_verts_teeth])], dim=0) # (V, 5) -> (V+num_verts_teeth, 5)
+ self.lbs_weights[vid_teeth_upper, 1] += 1 # move with neck
+ self.lbs_weights[vid_teeth_lower, 2] += 1 # move with jaw
+
+ # add faces for teeth
+ f_teeth_upper = torch.tensor([
+ [0, 31, 30], #0
+ [0, 1, 31], #1
+ [1, 32, 31], #2
+ [1, 2, 32], #3
+ [2, 33, 32], #4
+ [2, 3, 33], #5
+ [3, 34, 33], #6
+ [3, 4, 34], #7
+ [4, 35, 34], #8
+ [4, 5, 35], #9
+ [5, 36, 35], #10
+ [5, 6, 36], #11
+ [6, 37, 36], #12
+ [6, 7, 37], #13
+ [7, 8, 37], #14
+ [8, 38, 37], #15
+ [8, 9, 38], #16
+ [9, 39, 38], #17
+ [9, 10, 39], #18
+ [10, 40, 39], #19
+ [10, 11, 40], #20
+ [11, 41, 40], #21
+ [11, 12, 41], #22
+ [12, 42, 41], #23
+ [12, 13, 42], #24
+ [13, 43, 42], #25
+ [13, 14, 43], #26
+ [14, 44, 43], #27
+ [60, 75, 76], # 56
+ [60, 76, 61], # 57
+ [61, 76, 77], # 58
+ [61, 77, 62], # 59
+ [62, 77, 78], # 60
+ [62, 78, 63], # 61
+ [63, 78, 79], # 62
+ [63, 79, 64], # 63
+ [64, 79, 80], # 64
+ [64, 80, 65], # 65
+ [65, 80, 81], # 66
+ [65, 81, 66], # 67
+ [66, 81, 82], # 68
+ [66, 82, 67], # 69
+ [67, 82, 68], # 70
+ [68, 82, 83], # 71
+ [68, 83, 69], # 72
+ [69, 83, 84], # 73
+ [69, 84, 70], # 74
+ [70, 84, 85], # 75
+ [70, 85, 71], # 76
+ [71, 85, 86], # 77
+ [71, 86, 72], # 78
+ [72, 86, 87], # 79
+ [72, 87, 73], # 80
+ [73, 87, 88], # 81
+ [73, 88, 74], # 82
+ [74, 88, 89], # 83
+ [75, 30, 76], # 84
+ [76, 30, 31], # 85
+ [76, 31, 77], # 86
+ [77, 31, 32], # 87
+ [77, 32, 78], # 88
+ [78, 32, 33], # 89
+ [78, 33, 79], # 90
+ [79, 33, 34], # 91
+ [79, 34, 80], # 92
+ [80, 34, 35], # 93
+ [80, 35, 81], # 94
+ [81, 35, 36], # 95
+ [81, 36, 82], # 96
+ [82, 36, 37], # 97
+ [82, 37, 38], # 98
+ [82, 38, 83], # 99
+ [83, 38, 39], # 100
+ [83, 39, 84], # 101
+ [84, 39, 40], # 102
+ [84, 40, 85], # 103
+ [85, 40, 41], # 104
+ [85, 41, 86], # 105
+ [86, 41, 42], # 106
+ [86, 42, 87], # 107
+ [87, 42, 43], # 108
+ [87, 43, 88], # 109
+ [88, 43, 44], # 110
+ [88, 44, 89], # 111
+ ])
+ f_teeth_lower = torch.tensor([
+ [45, 46, 15], # 28
+ [46, 16, 15], # 29
+ [46, 47, 16], # 30
+ [47, 17, 16], # 31
+ [47, 48, 17], # 32
+ [48, 18, 17], # 33
+ [48, 49, 18], # 34
+ [49, 19, 18], # 35
+ [49, 50, 19], # 36
+ [50, 20, 19], # 37
+ [50, 51, 20], # 38
+ [51, 21, 20], # 39
+ [51, 52, 21], # 40
+ [52, 22, 21], # 41
+ [52, 23, 22], # 42
+ [52, 53, 23], # 43
+ [53, 24, 23], # 44
+ [53, 54, 24], # 45
+ [54, 25, 24], # 46
+ [54, 55, 25], # 47
+ [55, 26, 25], # 48
+ [55, 56, 26], # 49
+ [56, 27, 26], # 50
+ [56, 57, 27], # 51
+ [57, 28, 27], # 52
+ [57, 58, 28], # 53
+ [58, 29, 28], # 54
+ [58, 59, 29], # 55
+ [90, 106, 105], # 112
+ [90, 91, 106], # 113
+ [91, 107, 106], # 114
+ [91, 92, 107], # 115
+ [92, 108, 107], # 116
+ [92, 93, 108], # 117
+ [93, 109, 108], # 118
+ [93, 94, 109], # 119
+ [94, 110, 109], # 120
+ [94, 95, 110], # 121
+ [95, 111, 110], # 122
+ [95, 96, 111], # 123
+ [96, 112, 111], # 124
+ [96, 97, 112], # 125
+ [97, 98, 112], # 126
+ [98, 113, 112], # 127
+ [98, 99, 113], # 128
+ [99, 114, 113], # 129
+ [99, 100, 114], # 130
+ [100, 115, 114], # 131
+ [100, 101, 115], # 132
+ [101, 116, 115], # 133
+ [101, 102, 116], # 134
+ [102, 117, 116], # 135
+ [102, 103, 117], # 136
+ [103, 118, 117], # 137
+ [103, 104, 118], # 138
+ [104, 119, 118], # 139
+ [105, 106, 45], # 140
+ [106, 46, 45], # 141
+ [106, 107, 46], # 142
+ [107, 47, 46], # 143
+ [107, 108, 47], # 144
+ [108, 48, 47], # 145
+ [108, 109, 48], # 146
+ [109, 49, 48], # 147
+ [109, 110, 49], # 148
+ [110, 50, 49], # 149
+ [110, 111, 50], # 150
+ [111, 51, 50], # 151
+ [111, 112, 51], # 152
+ [112, 52, 51], # 153
+ [112, 53, 52], # 154
+ [112, 113, 53], # 155
+ [113, 54, 53], # 156
+ [113, 114, 54], # 157
+ [114, 55, 54], # 158
+ [114, 115, 55], # 159
+ [115, 56, 55], # 160
+ [115, 116, 56], # 161
+ [116, 57, 56], # 162
+ [116, 117, 57], # 163
+ [117, 58, 57], # 164
+ [117, 118, 58], # 165
+ [118, 59, 58], # 166
+ [118, 119, 59], # 167
+ ])
+ self.faces = torch.cat([self.faces, f_teeth_upper+num_verts_orig, f_teeth_lower+num_verts_orig], dim=0)
+ self.textures_idx = torch.cat([self.textures_idx, f_teeth_upper+num_verts_uv_orig, f_teeth_lower+num_verts_uv_orig], dim=0)
+
+ self.mask.num_verts = self.v_template.shape[0]
+ self.mask.update(self.faces, self.textures_idx)
+
+
+ def connect_lip_inside(self):
+ f_lip_connect = torch.tensor([
+ [1594, 1595, 1572], #0
+ [1595, 1746, 1572], #1
+ [1572, 1746, 1573], #2
+ [1746, 1747, 1573], #3
+ [1573, 1747, 1860], #4
+ [1747, 1742, 1860], #5
+ [1860, 1742, 1862], #6
+ [1742, 1739, 1862], #7
+ [1862, 1739, 1830], #8
+ [1739, 1665, 1830], #9
+ [1830, 1665, 1835], #10
+ [1665, 1666, 1835], #11
+ [1835, 1666, 1852], #12
+ [1666, 3514, 1852], #13
+ [1852, 3514, 3497], #14
+ [3497, 3514, 2941], #15
+ [3514, 2783, 2941], #16
+ [2941, 2783, 2933], #17
+ [2783, 2782, 2933], #18
+ [2933, 2782, 2930], #19
+ [2782, 2854, 2930], #20
+ [2930, 2854, 2945], #21
+ [2854, 2857, 2945], #22
+ [2945, 2857, 2943], #23
+ [2857, 2862, 2943], #24
+ [2943, 2862, 2709], #25
+ [2862, 2861, 2709], #26
+ [2709, 2861, 2708], #27
+ [2861, 2731, 2708], #28
+ [2731, 2730, 2708], #29
+ ])
+ self.faces = torch.cat([self.faces, f_lip_connect], dim=0)
+
+ self.mask.update(self.faces)
+
+ def remove_lip_inside(self):
+ fid = self.mask.get_fid_except_region(['lip_inside'])
+ self.faces = self.faces[fid]
+ self.textures_idx = self.textures_idx[fid]
+ self.mask.update(self.faces, self.textures_idx)
+
+ def remove_torso(self):
+ fid = self.mask.get_fid_except_region(['boundary'])
+ self.faces = self.faces[fid]
+ # self.textures_idx = self.textures_idx[fid] # TODO: have to update textures_idx for connect_lip_inside before enabling this
+ self.mask.update(self.faces, self.textures_idx)
+
+ def disable_deformation_on_torso(self, n_expr):
+ vid = self.mask.get_vid_by_region(['boundary', 'neck_lower'])
+ self.shapedirs[vid, -n_expr:] = 0
+
+ vid = self.mask.get_vid_by_region(['boundary'])
+ self.lbs_weights[vid, -3:] = 0
+
+ def add_lbs_color(self):
+ num_joints = self.lbs_weights.shape[1]
+ color_indices = np.array(range(num_joints))
+ cmap = cm.get_cmap("Set1")
+ colors = cmap(color_indices)[:, :3] # (num_joints, 3)
+ lbs_color = self.lbs_weights @ colors
+ self.register_buffer("lbs_color", lbs_color.float(), persistent=False)
+
+ def forward(
+ self,
+ shape,
+ expr,
+ rotation,
+ neck,
+ jaw,
+ eyes,
+ translation,
+ zero_centered_at_root_node=False, # otherwise, zero centered at the face
+ return_landmarks=True,
+ return_verts_cano=False,
+ static_offset=None,
+ dynamic_offset=None,
+ ):
+ """
+ Input:
+ shape_params: N X number of shape parameters
+ expression_params: N X number of expression parameters
+ pose_params: N X number of pose parameters (6)
+ return:d
+ vertices: N X V X 3
+ landmarks: N X number of landmarks X 3
+ """
+ batch_size = shape.shape[0]
+
+ betas = torch.cat([shape, expr], dim=1)
+ full_pose = torch.cat([rotation, neck, jaw, eyes], dim=1)
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Add shape contribution
+ v_shaped = template_vertices + blend_shapes(betas, self.shapedirs)
+
+ # Add personal offsets
+ if static_offset is not None:
+ v_shaped += static_offset
+ if dynamic_offset is not None:
+ v_shaped += dynamic_offset
+
+ vertices, J, mat_rot = lbs(
+ full_pose,
+ v_shaped,
+ self.posedirs,
+ self.J_regressor,
+ self.parents,
+ self.lbs_weights,
+ dtype=self.dtype,
+ )
+
+ if zero_centered_at_root_node:
+ vertices = vertices - J[:, [0]]
+ J = J - J[:, [0]]
+
+ vertices = vertices + translation[:, None, :]
+ J = J + translation[:, None, :]
+
+ ret_vals = [vertices]
+
+ if return_verts_cano:
+ ret_vals.append(v_shaped)
+
+ # compute landmarks if desired
+ if return_landmarks:
+ bz = vertices.shape[0]
+ landmarks = vertices2landmarks(
+ vertices,
+ self.faces,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1),
+ )
+ ret_vals.append(landmarks)
+
+ if len(ret_vals) > 1:
+ return ret_vals
+ else:
+ return ret_vals[0]
+
+
+class FlameTexPainted(nn.Module):
+ def __init__(self, tex_size=512, painted_tex_path=FLAME_PAINTED_TEX_PATH):
+ super().__init__()
+ logger.info("Initializing FLAME painted texture model...")
+ self.tex_size = tex_size
+
+ tex_painted = torch.tensor(np.array(Image.open(painted_tex_path))[:, :, :3]) / 255
+ tex_painted = tex_painted[None, ...].permute(0, 3, 1, 2)
+ if tex_painted.shape[-1] != self.tex_size or tex_painted.shape[-2] != self.tex_size:
+ tex_painted = F.interpolate(tex_painted, [self.tex_size, self.tex_size])
+ self.register_buffer("tex_painted", tex_painted)
+
+ def forward(self):
+ return self.tex_painted
+
+
+class FlameTexPCA(nn.Module):
+ def __init__(self, tex_params, tex_size=512, tex_space_path=FLAME_TEX_PATH):
+ super().__init__()
+ logger.info("Initializing FLAME PCA texture model...")
+ self.tex_size = tex_size
+ tex_params = tex_params
+ tex_space = np.load(tex_space_path)
+ texture_mean = tex_space["mean"].reshape(1, -1)
+ texture_basis = tex_space["tex_dir"].reshape(-1, 200)
+ texture_mean = torch.from_numpy(texture_mean).float()[None, ...]
+ texture_basis = torch.from_numpy(texture_basis[:, :tex_params]).float()[
+ None, ...
+ ]
+ self.register_buffer("texture_mean", texture_mean)
+ self.register_buffer("texture_basis", texture_basis)
+
+ def forward(self, texcode):
+ texture = self.texture_mean + (self.texture_basis * texcode[:, None, :]).sum(-1)
+ texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
+ texture = F.interpolate(texture, [self.tex_size, self.tex_size])
+ texture = texture[:, [2, 1, 0], :, :]
+ texture = texture / 255.0
+ return texture.clamp(0, 1)
+
+
+class BufferContainer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def __repr__(self):
+ main_str = super().__repr__() + '\n'
+ for name, buf in self.named_buffers():
+ main_str += f' {name:20}\t{buf.shape}\t{buf.dtype}\n'
+ return main_str
+
+ def __iter__(self):
+ for name, buf in self.named_buffers():
+ yield name, buf
+
+ def keys(self):
+ return [name for name, buf in self.named_buffers()]
+
+ def items(self):
+ return [(name, buf) for name, buf in self.named_buffers()]
+
+
+class FlameMask(nn.Module):
+ def __init__(
+ self,
+ flame_parts_path=FLAME_PARTS_PATH,
+ faces=None,
+ faces_t=None,
+ num_verts=5023,
+ num_faces=9976,
+ face_clusters=[],
+ ):
+ super().__init__()
+ self.faces = faces
+ self.faces_t = faces_t
+ self.face_clusters = face_clusters
+ self.num_verts = num_verts
+ if faces is not None:
+ self.num_faces = faces.shape[0]
+ else:
+ self.num_faces = num_faces
+
+ self.process_vertex_mask(flame_parts_path)
+
+ if self.faces is not None:
+ self.construct_vid_table()
+ self.process_face_mask(self.faces)
+ self.process_face_clusters(self.face_clusters)
+ if self.faces_t is not None:
+ self.process_vt_mask(self.faces, self.faces_t)
+
+ def update(self, faces=None, faces_t=None, face_clusters=None):
+ """Update the faces properties when vertex masks are changed"""
+ if faces is not None:
+ self.faces = faces
+ self.num_faces = faces.shape[0]
+ if faces_t is not None:
+ self.faces_t = faces_t
+ if face_clusters is not None:
+ self.face_clusters = face_clusters
+
+ self.construct_vid_table()
+ self.process_face_mask(self.faces)
+ self.process_face_clusters(self.face_clusters)
+ if self.faces_t is not None:
+ self.process_vt_mask(self.faces, self.faces_t)
+
+ def process_vertex_mask(self, flame_parts_path):
+ """Load the vertex masks from the FLAME model and add custom masks"""
+ logger.info("Processing vertex masks for FLAME...")
+
+ part_masks = np.load(flame_parts_path, allow_pickle=True, encoding="latin1")
+ """ Available part masks from the FLAME model:
+ face, neck, scalp, boundary, right_eyeball, left_eyeball,
+ right_ear, left_ear, forehead, eye_region, nose, lips,
+ right_eye_region, left_eye_region.
+ """
+
+ self.v = BufferContainer()
+ for k, v_mask in part_masks.items():
+ self.v.register_buffer(k, torch.tensor(v_mask, dtype=torch.long))
+
+ self.create_custom_mask()
+
+ def create_custom_mask(self):
+ """Add some cutom masks based on the original FLAME masks"""
+
+ self.v.register_buffer("neck_left_point", torch.tensor([3193]))
+ self.v.register_buffer("neck_right_point", torch.tensor([3296]))
+ self.v.register_buffer("front_middle_bottom_point_boundary", torch.tensor([3285]))
+ self.v.register_buffer("back_middle_bottom_point_boundary", torch.tensor([3248]))
+
+ self.v.register_buffer(
+ "neck_top",
+ torch.tensor([
+ 10, 11, 111, 112, 784, 795, 1325, 1901, 2115, 2162, 2251, 2254, 2483, 2979, 3142, 3174, 3441, 3442, 3443, 3444, 3445, 3446, 3447, 3448, 3449, 3562, 3673, 3676, 3677, 3678, 3679, 3680, 3681, 3685,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_ring_upper",
+ torch.tensor([
+ 1595, 1746, 1747, 1742, 1739, 1665, 1666, 3514, 2783, 2782, 2854, 2857, 2862, 2861, 2731
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_ring_lower",
+ torch.tensor([
+ 1572, 1573, 1860, 1862, 1830, 1835, 1852, 3497, 2941, 2933, 2930, 2945, 2943, 2709, 2708
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_outside_ring_upper",
+ torch.tensor([
+ 1713, 1715, 1716, 1735, 1696, 1694, 1657, 3543, 2774, 2811, 2813, 2850, 2833, 2832, 2830
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_outside_ring_lower",
+ torch.tensor([
+ 1576, 1577, 1773, 1774, 1795, 1802, 1865, 3503, 2948, 2905, 2898, 2881, 2880, 2713, 2712
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_upper",
+ torch.tensor([
+ 1588, 1589, 1590, 1591, 1594, 1595, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1724, 1725, 1739, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 2724, 2725, 2726, 2727, 2730, 2731, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2841, 2842, 2854, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 3514, 3547, 3549,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside_lower",
+ torch.tensor([
+ 1572, 1573, 1592, 1593, 1764, 1765, 1779, 1780, 1781, 1830, 1831, 1832, 1835, 1846, 1847, 1851, 1852, 1854, 1860, 1861, 1862, 2708, 2709, 2728, 2729, 2872, 2873, 2886, 2887, 2888, 2930, 2931, 2932, 2933, 2935, 2936, 2940, 2941, 2942, 2943, 2944, 2945, 3497, 3500, 3512,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lip_inside",
+ torch.tensor([
+ 1572, 1573, 1580, 1581, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1718, 1719, 1722, 1724, 1725, 1728, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1764, 1765, 1777, 1778, 1779, 1780, 1781, 1782, 1827, 1830, 1831, 1832, 1835, 1836, 1846, 1847, 1851, 1852, 1854, 1860, 1861, 1862, 2708, 2709, 2716, 2717, 2724, 2725, 2726, 2727, 2728, 2729, 2730, 2731, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2784, 2785, 2835, 2836, 2839, 2841, 2842, 2843, 2854, 2855, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 2863, 2872, 2873, 2884, 2885, 2886, 2887, 2888, 2889, 2929, 2930, 2931, 2932, 2933, 2934, 2935, 2936, 2940, 2941, 2942, 2943, 2944, 2945, 3497, 3500, 3512, 3513, 3514, 3533, 3547, 3549,
+ ])
+ )
+
+ self.v.register_buffer(
+ "neck_upper",
+ torch.tensor([
+ 10, 11, 12, 13, 14, 15, 111, 112, 219, 220, 221, 222, 372, 373, 374, 375, 462, 463, 496, 497, 552, 553, 558, 559, 563, 564, 649, 650, 736, 737, 784, 795, 1210, 1211, 1212, 1213, 1325, 1326, 1359, 1360, 1386, 1726, 1727, 1759, 1790, 1886, 1898, 1901, 1931, 1932, 1933, 1934, 1940, 1941, 1948, 1949, 2036, 2115, 2149, 2150, 2151, 2162, 2218, 2219, 2251, 2254, 2483, 2484, 2531, 2870, 2893, 2964, 2976, 2979, 3012, 3013, 3142, 3174, 3184, 3185, 3186, 3187, 3188, 3189, 3193, 3194, 3196, 3199, 3200, 3202, 3203, 3206, 3209, 3281, 3282, 3286, 3291, 3292, 3296, 3297, 3299, 3302, 3303, 3305, 3306, 3309, 3312, 3376, 3441, 3442, 3443, 3444, 3445, 3446, 3447, 3448, 3449, 3452, 3453, 3454, 3455, 3456, 3457, 3458, 3459, 3460, 3461, 3462, 3463, 3494, 3496, 3544, 3562, 3673, 3676, 3677, 3678, 3679, 3680, 3681, 3685, 3695, 3697, 3698, 3701, 3703, 3707, 3709, 3713,
+ ])
+ )
+
+ self.v.register_buffer(
+ "neck_lower",
+ torch.tensor([
+ 3188, 3189, 3190, 3191, 3192, 3193, 3194, 3195, 3196, 3197, 3198, 3199, 3200, 3201, 3202, 3203, 3204, 3205, 3206, 3207, 3208, 3209, 3210, 3211, 3212, 3213, 3214, 3215, 3220, 3222, 3223, 3231, 3232, 3233, 3234, 3235, 3236, 3237, 3238, 3239, 3240, 3241, 3242, 3243, 3244, 3245, 3246, 3247, 3250, 3251, 3253, 3254, 3263, 3264, 3265, 3266, 3267, 3268, 3269, 3270, 3275, 3276, 3277, 3278, 3281, 3282, 3283, 3286, 3288, 3290, 3291, 3292, 3293, 3294, 3295, 3296, 3297, 3298, 3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308, 3309, 3310, 3311, 3312, 3313, 3314, 3315, 3316, 3317, 3318, 3323, 3332, 3333, 3334, 3335, 3336, 3337, 3338, 3339, 3340, 3341, 3342, 3343, 3344, 3345, 3346, 3347, 3348, 3349, 3350, 3352, 3353, 3362, 3363, 3364, 3365, 3366, 3367, 3368, 3369, 3376, 3378,
+ ])
+ )
+
+ # As a subset of "boundary", "bottomline" only contains vertices on the edge
+ self.v.register_buffer(
+ "bottomline",
+ torch.tensor([
+ 3218, 3219, 3226, 3272, 3273, 3229, 3228, 3261, 3260, 3248, 3359, 3360, 3329, 3330, 3372, 3371, 3327, 3322, 3321, 3355, 3354, 3356, 3357, 3379, 3285, 3289, 3258, 3257, 3255, 3256
+ ])
+ )
+
+ self.v.register_buffer(
+ "left_iris",
+ torch.tensor([
+ 3931, 3932, 3933, 3935, 3936, 3937, 3939, 3940, 3941, 3943, 3944, 3945, 3947, 3948, 3949, 3951, 3952, 3953, 3955, 3956, 3957, 3959, 3960, 3961, 3963, 3964, 3965, 3967, 3968, 3969, 3971, 3972, 3973, 3975, 3976, 3977, 3979, 3980, 3981, 3983, 3984, 3985, 3987, 3988, 3989, 3991, 3992, 3993, 3995, 3996, 3997, 3999, 4000, 4001, 4003, 4004, 4005, 4007, 4008, 4009, 4011, 4012, 4013, 4015, 4016, 4017, 4019, 4020, 4021, 4023, 4024, 4025, 4027, 4028, 4029, 4031, 4032, 4033, 4035, 4036, 4037, 4039, 4040, 4041, 4043, 4044, 4045, 4047, 4048, 4049, 4051, 4052, 4053, 4054, 4056, 4057, 4058,
+ ])
+ )
+
+ self.v.register_buffer(
+ "right_iris",
+ torch.tensor([
+ 4477, 4478, 4479, 4481, 4482, 4483, 4485, 4486, 4487, 4489, 4490, 4491, 4493, 4494, 4495, 4497, 4498, 4499, 4501, 4502, 4503, 4505, 4506, 4507, 4509, 4510, 4511, 4513, 4514, 4515, 4517, 4518, 4519, 4521, 4522, 4523, 4525, 4526, 4527, 4529, 4530, 4531, 4533, 4534, 4535, 4537, 4538, 4539, 4541, 4542, 4543, 4545, 4546, 4547, 4549, 4550, 4551, 4553, 4554, 4555, 4557, 4558, 4559, 4561, 4562, 4563, 4565, 4566, 4567, 4569, 4570, 4571, 4573, 4574, 4575, 4577, 4578, 4579, 4581, 4582, 4583, 4585, 4586, 4587, 4589, 4590, 4591, 4593, 4594, 4595, 4597, 4598, 4599, 4600, 4602, 4603, 4604,
+ ])
+ )
+
+ self.v.register_buffer(
+ "left_eyelid", # 30 vertices
+ torch.tensor([
+ 807, 808, 809, 814, 815, 816, 821, 822, 823, 824, 825, 826, 827, 828, 829, 841, 842, 848, 864, 865, 877, 878, 879, 880, 881, 882, 883, 884, 885, 896, 897, 903, 904, 905, 922, 923, 924, 926, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 958, 959, 991, 992, 993, 994, 995, 999, 1000, 1003, 1006, 1008, 1011, 1023, 1033, 1034, 1045, 1046, 1059, 1060, 1061, 1062, 1093, 1096, 1101, 1108, 1113, 1114, 1115, 1125, 1126, 1132, 1134, 1135, 1142, 1143, 1144, 1146, 1147, 1150, 1151, 1152, 1153, 1154, 1170, 1175, 1182, 1183, 1194, 1195, 1200, 1201, 1202, 1216, 1217, 1218, 1224, 1227, 1230, 1232, 1233, 1243, 1244, 1283, 1289, 1292, 1293, 1294, 1320, 1329, 1331, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1361, 3827, 3832, 3833, 3835, 3853, 3855, 3856, 3861,
+ ])
+ )
+
+ self.v.register_buffer(
+ "right_eyelid", # 30 vertices
+ torch.tensor([
+ 2264, 2265, 2266, 2267, 2268, 2269, 2270, 2271, 2272, 2273, 2274, 2275, 2276, 2277, 2278, 2282, 2283, 2286, 2287, 2288, 2289, 2290, 2291, 2292, 2293, 2294, 2295, 2296, 2297, 2298, 2299, 2303, 2304, 2305, 2312, 2313, 2314, 2315, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2364, 2365, 2367, 2369, 2381, 2382, 2383, 2386, 2387, 2388, 2389, 2390, 2391, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2411, 2412, 2416, 2417, 2418, 2419, 2420, 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2436, 2437, 2440, 2441, 2446, 2447, 2448, 2449, 2450, 2451, 2452, 2453, 2454, 2457, 2460, 2461, 2462, 2465, 2466, 2467, 2470, 2471, 2472, 2473, 2478, 2485, 2486, 2487, 2488, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2503, 2504, 2505, 2506, 2507, 2508, 2509, 2510, 3619, 3631, 3632, 3638, 3687, 3689, 3690, 3700,
+ ])
+ )
+
+ self.v.register_buffer(
+ "lips_tight", # 30 vertices
+ torch.tensor([
+ 1572, 1573, 1578, 1580, 1581, 1582, 1583, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1750, 1751, 1758, 1764, 1765, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1787, 1788, 1789, 1791, 1792, 1793, 1794, 1795, 1802, 1803, 1804, 1826, 1827, 1830, 1831, 1832, 1835, 1836, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1854, 1860, 1861, 1862, 1865, 2708, 2709, 2714, 2716, 2717, 2718, 2719, 2724, 2725, 2726, 2727, 2728, 2729, 2730, 2731, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2784, 2785, 2786, 2787, 2835, 2836, 2837, 2838, 2839, 2840, 2841, 2842, 2843, 2844, 2845, 2846, 2847, 2848, 2849, 2851, 2852, 2853, 2854, 2855, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 2863, 2865, 2866, 2869, 2872, 2873, 2880, 2881, 2882, 2883, 2884, 2885, 2886, 2887, 2888, 2889, 2890, 2891, 2892, 2894, 2895, 2896, 2897, 2898, 2905, 2906, 2907, 2928, 2929, 2930, 2931, 2932, 2933, 2934, 2935, 2936, 2937, 2938, 2939, 2940, 2941, 2942, 2943, 2944, 2945, 2948, 3497, 3500, 3503, 3504, 3506, 3509, 3512, 3513, 3514, 3531, 3533, 3546, 3547, 3549,
+ ])
+ )
+
+ self.v.register_buffer(
+ "left_half",
+ torch.tensor([
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 530, 531, 532, 533, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 588, 589, 590, 591, 592, 593, 594, 603, 604, 605, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 638, 639, 644, 645, 646, 647, 648, 649, 650, 667, 668, 669, 670, 671, 672, 673, 674, 679, 680, 681, 682, 683, 688, 691, 692, 693, 694, 695, 696, 697, 702, 703, 704, 705, 706, 707, 708, 709, 712, 713, 714, 715, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 745, 746, 747, 748, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 783, 784, 785, 786, 795, 796, 797, 798, 799, 802, 803, 804, 805, 806, 807, 808, 809, 814, 815, 816, 821, 822, 823, 824, 825, 826, 827, 828, 829, 837, 838, 840, 841, 842, 846, 847, 848, 864, 865, 877, 878, 879, 880, 881, 882, 883, 884, 885, 896, 897, 898, 899, 902, 903, 904, 905, 906, 907, 908, 909, 918, 919, 922, 923, 924, 926, 927, 928, 929, 939, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 977, 978, 979, 980, 985, 986, 991, 992, 993, 994, 995, 999, 1000, 1001, 1002, 1003, 1006, 1007, 1008, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1033, 1034, 1043, 1044, 1045, 1046, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1068, 1075, 1085, 1086, 1087, 1088, 1092, 1093, 1096, 1101, 1108, 1113, 1114, 1115, 1116, 1117, 1125, 1126, 1127, 1128, 1129, 1132, 1134, 1135, 1142, 1143, 1144, 1146, 1147, 1150, 1151, 1152, 1153, 1154, 1155, 1161, 1162, 1163, 1164, 1168, 1169, 1170, 1175, 1176, 1181, 1182, 1183, 1184, 1189, 1190, 1193, 1194, 1195, 1200, 1201, 1202, 1216, 1217, 1218, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1232, 1233, 1241, 1242, 1243, 1244, 1283, 1284, 1287, 1289, 1292, 1293, 1294, 1298, 1299, 1308, 1309, 1320, 1321, 1322, 1323, 1324, 1325, 1326, 1329, 1331, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1361, 1362, 1363, 1364, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1373, 1374, 1375, 1376, 1377, 1378, 1383, 1384, 1385, 1386, 1387, 1388, 1389, 1390, 1391, 1396, 1397, 1398, 1399, 1400, 1401, 1402, 1403, 1404, 1405, 1410, 1411, 1412, 1413, 1414, 1415, 1416, 1417, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1430, 1431, 1432, 1433, 1434, 1435, 1436, 1437, 1438, 1439, 1440, 1441, 1442, 1443, 1444, 1445, 1446, 1447, 1448, 1449, 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458, 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1489, 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574, 1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1600, 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1617, 1618, 1623, 1624, 1625, 1626, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1672, 1673, 1674, 1675, 1676, 1677, 1678, 1679, 1680, 1681, 1682, 1683, 1684, 1685, 1686, 1687, 1688, 1689, 1690, 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1701, 1702, 1703, 1704, 1705, 1706, 1707, 1708, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1735, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1749, 1750, 1751, 1756, 1757, 1758, 1759, 1763, 1764, 1765, 1766, 1767, 1768, 1769, 1770, 1771, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1787, 1788, 1789, 1790, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1823, 1824, 1825, 1826, 1827, 1830, 1831, 1832, 1835, 1836, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1854, 1860, 1861, 1862, 1863, 1864, 1865, 1866, 1867, 1868, 1869, 1871, 1872, 1873, 1874, 1875, 1876, 1877, 1878, 1879, 1880, 1881, 1886, 1887, 1888, 1889, 1890, 1891, 1892, 1893, 1894, 1895, 1896, 1897, 1898, 1899, 1900, 1901, 1902, 1903, 1904, 1905, 1906, 1907, 1908, 1909, 1910, 1911, 1914, 1915, 1917, 1918, 1919, 1920, 1921, 1922, 1923, 1924, 1925, 1926, 1927, 1928, 1938, 1939, 1942, 1943, 1944, 1945, 1946, 1947, 1948, 1949, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957, 1958, 1959, 1964, 1965, 1966, 1967, 1968, 1969, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2004, 2009, 2010, 2011, 2012, 2021, 2022, 2023, 2024, 2025, 2026, 2029, 2030, 2033, 2034, 2035, 2036, 2037, 2038, 2039, 2040, 2041, 2042, 2043, 2044, 2045, 2046, 2047, 2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055, 2056, 2057, 2058, 2059, 2060, 2061, 2062, 2063, 2064, 2065, 2066, 2067, 2068, 2069, 2070, 2071, 2072, 2073, 2074, 2075, 2076, 2077, 2078, 2079, 2080, 2081, 2082, 2083, 2092, 2093, 2094, 2095, 2096, 2097, 2098, 2099, 2100, 2101, 2102, 2103, 2104, 2105, 2106, 2107, 2108, 2109, 2110, 2111, 2112, 2113, 2114, 2115, 2116, 2117, 2118, 2119, 2120, 2121, 2122, 2125, 2126, 2127, 2134, 2135, 2136, 2137, 2138, 2139, 2140, 2141, 2142, 2143, 2148, 2151, 2152, 2153, 2154, 2155, 2156, 2157, 2158, 2159, 2160, 2161, 2162, 2163, 2164, 2169, 2170, 2171, 2172, 2173, 2174, 2175, 3186, 3187, 3188, 3189, 3190, 3191, 3192, 3193, 3194, 3195, 3196, 3197, 3198, 3199, 3200, 3201, 3202, 3203, 3204, 3205, 3206, 3207, 3208, 3209, 3210, 3211, 3212, 3213, 3214, 3215, 3216, 3217, 3218, 3219, 3220, 3221, 3222, 3223, 3224, 3225, 3226, 3227, 3228, 3229, 3230, 3231, 3232, 3233, 3234, 3235, 3236, 3237, 3238, 3239, 3240, 3241, 3242, 3243, 3244, 3245, 3246, 3247, 3248, 3249, 3250, 3251, 3252, 3253, 3254, 3255, 3256, 3257, 3258, 3259, 3260, 3261, 3262, 3263, 3264, 3265, 3266, 3267, 3268, 3269, 3270, 3271, 3272, 3273, 3274, 3275, 3276, 3277, 3278, 3279, 3280, 3281, 3282, 3283, 3284, 3285, 3286, 3287, 3288, 3289, 3290, 3399, 3400, 3401, 3404, 3414, 3442, 3457, 3459, 3461, 3463, 3487, 3494, 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3502, 3503, 3504, 3505, 3506, 3507, 3508, 3509, 3510, 3511, 3512, 3513, 3514, 3515, 3516, 3517, 3518, 3519, 3520, 3521, 3522, 3523, 3524, 3525, 3526, 3527, 3528, 3529, 3530, 3531, 3532, 3533, 3534, 3535, 3536, 3537, 3538, 3539, 3540, 3541, 3542, 3543, 3544, 3545, 3546, 3547, 3548, 3549, 3550, 3551, 3552, 3553, 3554, 3555, 3556, 3557, 3558, 3559, 3560, 3561, 3562, 3563, 3564, 3565, 3566, 3567, 3568, 3569, 3570, 3571, 3572, 3573, 3574, 3575, 3576, 3577, 3578, 3579, 3580, 3581, 3582, 3583, 3584, 3587, 3588, 3593, 3594, 3595, 3596, 3598, 3599, 3600, 3601, 3604, 3605, 3611, 3614, 3623, 3624, 3625, 3626, 3628, 3629, 3630, 3634, 3635, 3636, 3637, 3643, 3644, 3646, 3649, 3650, 3652, 3653, 3654, 3655, 3656, 3658, 3659, 3660, 3662, 3663, 3664, 3665, 3666, 3667, 3668, 3670, 3671, 3672, 3673, 3676, 3677, 3678, 3679, 3680, 3681, 3685, 3691, 3693, 3695, 3697, 3698, 3701, 3703, 3704, 3707, 3709, 3713, 3714, 3715, 3716, 3717, 3722, 3724, 3725, 3726, 3727, 3728, 3730, 3734, 3737, 3738, 3739, 3740, 3742, 3745, 3752, 3753, 3754, 3756, 3757, 3760, 3761, 3762, 3769, 3771, 3772, 3785, 3786, 3790, 3801, 3807, 3808, 3809, 3810, 3811, 3812, 3813, 3814, 3815, 3816, 3817, 3818, 3819, 3820, 3821, 3822, 3823, 3824, 3825, 3826, 3827, 3828, 3829, 3830, 3831, 3832, 3833, 3834, 3835, 3836, 3837, 3838, 3839, 3840, 3841, 3842, 3843, 3844, 3845, 3846, 3847, 3848, 3849, 3850, 3851, 3852, 3853, 3854, 3855, 3856, 3857, 3858, 3859, 3860, 3861, 3862, 3863, 3864, 3865, 3866, 3867, 3868, 3869, 3870, 3871, 3872, 3873, 3874, 3875, 3876, 3877, 3878, 3879, 3880, 3881, 3882, 3883, 3884, 3885, 3886, 3887, 3888, 3889, 3890, 3891, 3892, 3893, 3894, 3895, 3896, 3897, 3898, 3899, 3900, 3901, 3902, 3903, 3904, 3905, 3906, 3907, 3908, 3909, 3910, 3911, 3912, 3913, 3914, 3915, 3916, 3917, 3918, 3919, 3920, 3921, 3922, 3923, 3924, 3925, 3926, 3927, 3928, 3929, 3931, 3932, 3933, 3934, 3935, 3936, 3937, 3938, 3939, 3940, 3941, 3942, 3943, 3944, 3945, 3946, 3947, 3948, 3949, 3950, 3951, 3952, 3953, 3954, 3955, 3956, 3957, 3958, 3959, 3960, 3961, 3962, 3963, 3964, 3965, 3966, 3967, 3968, 3969, 3970, 3971, 3972, 3973, 3974, 3975, 3976, 3977, 3978, 3979, 3980, 3981, 3982, 3983, 3984, 3985, 3986, 3987, 3988, 3989, 3990, 3991, 3992, 3993, 3994, 3995, 3996, 3997, 3998, 3999, 4000, 4001, 4002, 4003, 4004, 4005, 4006, 4007, 4008, 4009, 4010, 4011, 4012, 4013, 4014, 4015, 4016, 4017, 4018, 4019, 4020, 4021, 4022, 4023, 4024, 4025, 4026, 4027, 4028, 4029, 4030, 4031, 4032, 4033, 4034, 4035, 4036, 4037, 4038, 4039, 4040, 4041, 4042, 4043, 4044, 4045, 4046, 4047, 4048, 4049, 4050, 4051, 4052, 4053, 4054, 4055, 4056, 4057, 4058, 4059, 4060, 4061, 4062, 4063, 4064, 4065, 4066, 4067, 4068, 4069, 4070, 4071, 4072, 4073, 4074, 4075, 4076, 4077, 4078, 4079, 4080, 4081, 4082, 4083, 4084, 4085, 4086, 4087, 4088, 4089, 4090, 4091, 4092, 4093, 4094, 4095, 4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103, 4104, 4105, 4106, 4107, 4108, 4109, 4110, 4111, 4112, 4113, 4114, 4115, 4116, 4117, 4118, 4119, 4120, 4121, 4122, 4123, 4124, 4125, 4126, 4127, 4128, 4129, 4130, 4131, 4132, 4133, 4134, 4135, 4136, 4137, 4138, 4139, 4140, 4141, 4142, 4143, 4144, 4145, 4146, 4147, 4148, 4149, 4150, 4151, 4152, 4153, 4154, 4155, 4156, 4157, 4158, 4159, 4160, 4161, 4162, 4163, 4164, 4165, 4166, 4167, 4168, 4169, 4170, 4171, 4172, 4173, 4174, 4175, 4176, 4177, 4178, 4179, 4180, 4181, 4182, 4183, 4184, 4185, 4186, 4187, 4188, 4189, 4190, 4191, 4192, 4193, 4194, 4195, 4196, 4197, 4198, 4199, 4200, 4201, 4202, 4203, 4204, 4205, 4206, 4207, 4208, 4209, 4210, 4211, 4212, 4213, 4214, 4215, 4216, 4217, 4218, 4219, 4220, 4221, 4222, 4223, 4224, 4225, 4226, 4227, 4228, 4229, 4230, 4231, 4232, 4233, 4234, 4235, 4236, 4237, 4238, 4239, 4240, 4241, 4242, 4243, 4244, 4245, 4246, 4247, 4248, 4249, 4250, 4251, 4252, 4253, 4254, 4255, 4256, 4257, 4258, 4259, 4260, 4261, 4262, 4263, 4264, 4265, 4266, 4267, 4268, 4269, 4270, 4271, 4272, 4273, 4274, 4275, 4276, 4277, 4278, 4279, 4280, 4281, 4282, 4283, 4284, 4285, 4286, 4287, 4288, 4289, 4290, 4291, 4292, 4293, 4294, 4295, 4296, 4297, 4298, 4299, 4300, 4301, 4302, 4303, 4304, 4305, 4306, 4307, 4308, 4309, 4310, 4311, 4312, 4313, 4314, 4315, 4316, 4317, 4318, 4319, 4320, 4321, 4322, 4323, 4324, 4325, 4326, 4327, 4328, 4329, 4330, 4331, 4332, 4333, 4334, 4335, 4336, 4337, 4338, 4339, 4340, 4341, 4342, 4343, 4344, 4345, 4346, 4347, 4348, 4349, 4350, 4351, 4352, 4353, 4354, 4355, 4356, 4357, 4358, 4359, 4360, 4361, 4362, 4363, 4364, 4365, 4366, 4367, 4368, 4369, 4370, 4371, 4372, 4373, 4374, 4375, 4376, 4377, 4378, 4379, 4380, 4381, 4382, 4383, 4384, 4385, 4386, 4387, 4388, 4389, 4390, 4391, 4392, 4393, 4394, 4395, 4396, 4397, 4398, 4399, 4400, 4401, 4402, 4403, 4404, 4405, 4406, 4407, 4408, 4409, 4410, 4411, 4412, 4413, 4414, 4415, 4416, 4417, 4418, 4419, 4420, 4421, 4422, 4423, 4424, 4425, 4426, 4427, 4428, 4429, 4430, 4431, 4432, 4433, 4434, 4435, 4436, 4437, 4438, 4439, 4440, 4441, 4442, 4443, 4444, 4445, 4446, 4447, 4448, 4449, 4450, 4451, 4452, 4453, 4454, 4455, 4456, 4457, 4458, 4459, 4460, 4461, 4462, 4463, 4464, 4465, 4466, 4467, 4468, 4469, 4470, 4471, 4472, 4473, 4474, 4475, 4476,
+ ])
+ )
+
+ self.v.register_buffer(
+ "right_half",
+ torch.tensor([
+ 19, 20, 21, 22, 23, 24, 25, 26, 109, 110, 111, 112, 219, 220, 221, 222, 335, 336, 337, 338, 522, 523, 524, 525, 526, 527, 528, 529, 534, 535, 536, 537, 554, 555, 556, 557, 584, 585, 586, 587, 595, 596, 597, 598, 599, 600, 601, 602, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 634, 635, 636, 637, 640, 641, 642, 643, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 675, 676, 677, 678, 684, 685, 686, 687, 689, 690, 698, 699, 700, 701, 710, 711, 716, 717, 718, 719, 720, 721, 722, 741, 742, 743, 744, 749, 750, 751, 752, 776, 777, 778, 779, 780, 781, 782, 787, 788, 789, 790, 791, 792, 793, 794, 800, 801, 810, 811, 812, 813, 817, 818, 819, 820, 830, 831, 832, 833, 834, 835, 836, 839, 843, 844, 845, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 900, 901, 910, 911, 912, 913, 914, 915, 916, 917, 920, 921, 925, 930, 931, 932, 933, 934, 935, 936, 937, 938, 940, 941, 956, 957, 973, 974, 975, 976, 981, 982, 983, 984, 987, 988, 989, 990, 996, 997, 998, 1004, 1005, 1009, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1066, 1067, 1069, 1070, 1071, 1072, 1073, 1074, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1089, 1090, 1091, 1094, 1095, 1097, 1098, 1099, 1100, 1102, 1103, 1104, 1105, 1106, 1107, 1109, 1110, 1111, 1112, 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1130, 1131, 1133, 1136, 1137, 1138, 1139, 1140, 1141, 1145, 1148, 1149, 1156, 1157, 1158, 1159, 1160, 1165, 1166, 1167, 1171, 1172, 1173, 1174, 1177, 1178, 1179, 1180, 1185, 1186, 1187, 1188, 1191, 1192, 1196, 1197, 1198, 1199, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1219, 1220, 1221, 1222, 1223, 1231, 1234, 1235, 1236, 1237, 1238, 1239, 1240, 1245, 1246, 1247, 1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1285, 1286, 1288, 1290, 1291, 1295, 1296, 1297, 1300, 1301, 1302, 1303, 1304, 1305, 1306, 1307, 1310, 1311, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1327, 1328, 1330, 1332, 1333, 1334, 1335, 1359, 1360, 1379, 1380, 1381, 1382, 1392, 1393, 1394, 1395, 1406, 1407, 1408, 1409, 1488, 1613, 1614, 1615, 1616, 1619, 1620, 1621, 1622, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1636, 1637, 1726, 1727, 1752, 1753, 1754, 1755, 1760, 1761, 1762, 1772, 1783, 1784, 1785, 1786, 1822, 1828, 1829, 1833, 1834, 1837, 1838, 1839, 1840, 1841, 1842, 1843, 1844, 1845, 1853, 1855, 1856, 1857, 1858, 1859, 1870, 1882, 1883, 1884, 1885, 1912, 1913, 1916, 1929, 1930, 1931, 1932, 1933, 1934, 1935, 1936, 1937, 1940, 1941, 1960, 1961, 1962, 1963, 1982, 1983, 1984, 1985, 2000, 2001, 2002, 2003, 2005, 2006, 2007, 2008, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2027, 2028, 2031, 2032, 2036, 2084, 2085, 2086, 2087, 2088, 2089, 2090, 2091, 2123, 2124, 2128, 2129, 2130, 2131, 2132, 2133, 2144, 2145, 2146, 2147, 2149, 2150, 2151, 2165, 2166, 2167, 2168, 2176, 2177, 2178, 2179, 2180, 2181, 2182, 2183, 2184, 2185, 2186, 2187, 2188, 2189, 2190, 2191, 2192, 2193, 2194, 2195, 2196, 2197, 2198, 2199, 2200, 2201, 2202, 2203, 2204, 2205, 2206, 2207, 2208, 2209, 2210, 2211, 2212, 2213, 2214, 2215, 2216, 2217, 2218, 2219, 2220, 2221, 2222, 2223, 2224, 2225, 2226, 2227, 2228, 2229, 2230, 2231, 2232, 2233, 2234, 2235, 2236, 2237, 2238, 2239, 2240, 2241, 2242, 2243, 2244, 2245, 2246, 2247, 2248, 2249, 2250, 2251, 2252, 2253, 2254, 2255, 2256, 2257, 2258, 2259, 2260, 2261, 2262, 2263, 2264, 2265, 2266, 2267, 2268, 2269, 2270, 2271, 2272, 2273, 2274, 2275, 2276, 2277, 2278, 2279, 2280, 2281, 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2290, 2291, 2292, 2293, 2294, 2295, 2296, 2297, 2298, 2299, 2300, 2301, 2302, 2303, 2304, 2305, 2306, 2307, 2308, 2309, 2310, 2311, 2312, 2313, 2314, 2315, 2316, 2317, 2318, 2319, 2320, 2321, 2322, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335, 2336, 2337, 2338, 2339, 2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347, 2348, 2349, 2350, 2351, 2352, 2353, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2362, 2363, 2364, 2365, 2366, 2367, 2368, 2369, 2370, 2371, 2372, 2373, 2374, 2375, 2376, 2377, 2378, 2379, 2380, 2381, 2382, 2383, 2384, 2385, 2386, 2387, 2388, 2389, 2390, 2391, 2392, 2393, 2394, 2395, 2396, 2397, 2398, 2399, 2400, 2401, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2409, 2410, 2411, 2412, 2413, 2414, 2415, 2416, 2417, 2418, 2419, 2420, 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2429, 2430, 2431, 2432, 2433, 2434, 2435, 2436, 2437, 2438, 2439, 2440, 2441, 2442, 2443, 2444, 2445, 2446, 2447, 2448, 2449, 2450, 2451, 2452, 2453, 2454, 2455, 2456, 2457, 2458, 2459, 2460, 2461, 2462, 2463, 2464, 2465, 2466, 2467, 2468, 2469, 2470, 2471, 2472, 2473, 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2482, 2483, 2484, 2485, 2486, 2487, 2488, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2497, 2498, 2499, 2500, 2501, 2502, 2503, 2504, 2505, 2506, 2507, 2508, 2509, 2510, 2511, 2512, 2513, 2514, 2515, 2516, 2517, 2518, 2519, 2520, 2521, 2522, 2523, 2524, 2525, 2526, 2527, 2528, 2529, 2530, 2531, 2532, 2533, 2534, 2535, 2536, 2537, 2538, 2539, 2540, 2541, 2542, 2543, 2544, 2545, 2546, 2547, 2548, 2549, 2550, 2551, 2552, 2553, 2554, 2555, 2556, 2557, 2558, 2559, 2560, 2561, 2562, 2563, 2564, 2565, 2566, 2567, 2568, 2569, 2570, 2571, 2572, 2573, 2574, 2575, 2576, 2577, 2578, 2579, 2580, 2581, 2582, 2583, 2584, 2585, 2586, 2587, 2588, 2589, 2590, 2591, 2592, 2593, 2594, 2595, 2596, 2597, 2598, 2599, 2600, 2601, 2602, 2603, 2604, 2605, 2606, 2607, 2608, 2609, 2610, 2611, 2612, 2613, 2614, 2615, 2616, 2617, 2618, 2619, 2620, 2621, 2622, 2623, 2624, 2625, 2626, 2627, 2628, 2629, 2630, 2631, 2632, 2633, 2634, 2635, 2636, 2637, 2638, 2639, 2640, 2641, 2642, 2643, 2644, 2645, 2646, 2647, 2648, 2649, 2650, 2651, 2652, 2653, 2654, 2655, 2656, 2657, 2658, 2659, 2660, 2661, 2662, 2663, 2664, 2665, 2666, 2667, 2668, 2669, 2670, 2671, 2672, 2673, 2674, 2675, 2676, 2677, 2678, 2679, 2680, 2681, 2682, 2683, 2684, 2685, 2686, 2687, 2688, 2689, 2690, 2691, 2692, 2693, 2694, 2695, 2696, 2697, 2698, 2699, 2700, 2701, 2702, 2703, 2704, 2705, 2706, 2707, 2708, 2709, 2710, 2711, 2712, 2713, 2714, 2715, 2716, 2717, 2718, 2719, 2720, 2721, 2722, 2723, 2724, 2725, 2726, 2727, 2728, 2729, 2730, 2731, 2732, 2733, 2734, 2735, 2736, 2737, 2738, 2739, 2740, 2741, 2742, 2743, 2744, 2745, 2746, 2747, 2748, 2749, 2750, 2751, 2752, 2753, 2754, 2755, 2756, 2757, 2758, 2759, 2760, 2761, 2762, 2763, 2764, 2765, 2766, 2767, 2768, 2769, 2770, 2771, 2772, 2773, 2774, 2775, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2784, 2785, 2786, 2787, 2788, 2789, 2790, 2791, 2792, 2793, 2794, 2795, 2796, 2797, 2798, 2799, 2800, 2801, 2802, 2803, 2804, 2805, 2806, 2807, 2808, 2809, 2810, 2811, 2812, 2813, 2814, 2815, 2816, 2817, 2818, 2819, 2820, 2821, 2822, 2823, 2824, 2825, 2826, 2827, 2828, 2829, 2830, 2831, 2832, 2833, 2834, 2835, 2836, 2837, 2838, 2839, 2840, 2841, 2842, 2843, 2844, 2845, 2846, 2847, 2848, 2849, 2850, 2851, 2852, 2853, 2854, 2855, 2856, 2857, 2858, 2859, 2860, 2861, 2862, 2863, 2864, 2865, 2866, 2867, 2868, 2869, 2870, 2871, 2872, 2873, 2874, 2875, 2876, 2877, 2878, 2879, 2880, 2881, 2882, 2883, 2884, 2885, 2886, 2887, 2888, 2889, 2890, 2891, 2892, 2893, 2894, 2895, 2896, 2897, 2898, 2899, 2900, 2901, 2902, 2903, 2904, 2905, 2906, 2907, 2908, 2909, 2910, 2911, 2912, 2913, 2914, 2915, 2916, 2917, 2918, 2919, 2920, 2921, 2922, 2923, 2924, 2925, 2926, 2927, 2928, 2929, 2930, 2931, 2932, 2933, 2934, 2935, 2936, 2937, 2938, 2939, 2940, 2941, 2942, 2943, 2944, 2945, 2946, 2947, 2948, 2949, 2950, 2951, 2952, 2953, 2954, 2955, 2956, 2957, 2958, 2959, 2960, 2961, 2962, 2963, 2964, 2965, 2966, 2967, 2968, 2969, 2970, 2971, 2972, 2973, 2974, 2975, 2976, 2977, 2978, 2979, 2980, 2981, 2982, 2983, 2984, 2985, 2986, 2987, 2988, 2989, 2990, 2991, 2992, 2993, 2994, 2995, 2996, 2997, 2998, 2999, 3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009, 3010, 3011, 3012, 3013, 3014, 3015, 3016, 3017, 3018, 3019, 3020, 3021, 3022, 3023, 3024, 3025, 3026, 3027, 3028, 3029, 3030, 3031, 3032, 3033, 3034, 3035, 3036, 3037, 3038, 3039, 3040, 3041, 3042, 3043, 3044, 3045, 3046, 3047, 3048, 3049, 3050, 3051, 3052, 3053, 3054, 3055, 3056, 3057, 3058, 3059, 3060, 3061, 3062, 3063, 3064, 3065, 3066, 3067, 3068, 3069, 3070, 3071, 3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079, 3080, 3081, 3082, 3083, 3084, 3085, 3086, 3087, 3088, 3089, 3090, 3091, 3092, 3093, 3094, 3095, 3096, 3097, 3098, 3099, 3100, 3101, 3102, 3103, 3104, 3105, 3106, 3107, 3108, 3109, 3110, 3111, 3112, 3113, 3114, 3115, 3116, 3117, 3118, 3119, 3120, 3121, 3122, 3123, 3124, 3125, 3126, 3127, 3128, 3129, 3130, 3131, 3132, 3133, 3134, 3135, 3136, 3137, 3138, 3139, 3140, 3141, 3142, 3143, 3144, 3145, 3146, 3147, 3148, 3149, 3150, 3151, 3152, 3153, 3154, 3155, 3156, 3157, 3158, 3159, 3160, 3161, 3162, 3163, 3164, 3165, 3166, 3167, 3168, 3169, 3170, 3171, 3172, 3173, 3174, 3175, 3176, 3177, 3178, 3179, 3180, 3181, 3182, 3183, 3184, 3185, 3222, 3223, 3248, 3249, 3275, 3276, 3277, 3278, 3281, 3282, 3283, 3284, 3285, 3290, 3291, 3292, 3293, 3294, 3295, 3296, 3297, 3298, 3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308, 3309, 3310, 3311, 3312, 3313, 3314, 3315, 3316, 3317, 3318, 3319, 3320, 3321, 3322, 3323, 3324, 3325, 3326, 3327, 3328, 3329, 3330, 3331, 3332, 3333, 3334, 3335, 3336, 3337, 3338, 3339, 3340, 3341, 3342, 3343, 3344, 3345, 3346, 3347, 3348, 3349, 3350, 3351, 3352, 3353, 3354, 3355, 3356, 3357, 3358, 3359, 3360, 3361, 3362, 3363, 3364, 3365, 3366, 3367, 3368, 3369, 3370, 3371, 3372, 3373, 3374, 3375, 3376, 3377, 3378, 3379, 3380, 3381, 3382, 3383, 3384, 3385, 3386, 3387, 3388, 3389, 3390, 3391, 3392, 3393, 3394, 3395, 3396, 3397, 3398, 3399, 3400, 3401, 3402, 3403, 3404, 3405, 3406, 3407, 3408, 3409, 3410, 3411, 3412, 3413, 3414, 3415, 3416, 3417, 3418, 3419, 3420, 3421, 3422, 3423, 3424, 3425, 3426, 3427, 3428, 3429, 3430, 3431, 3432, 3433, 3434, 3435, 3436, 3437, 3438, 3439, 3440, 3441, 3442, 3443, 3444, 3445, 3446, 3447, 3448, 3449, 3450, 3451, 3452, 3453, 3454, 3455, 3456, 3457, 3458, 3459, 3460, 3461, 3462, 3463, 3464, 3465, 3466, 3467, 3468, 3469, 3470, 3471, 3472, 3473, 3474, 3475, 3476, 3477, 3478, 3479, 3480, 3481, 3482, 3483, 3484, 3485, 3486, 3487, 3488, 3489, 3490, 3491, 3492, 3493, 3494, 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3502, 3503, 3504, 3505, 3506, 3507, 3508, 3509, 3510, 3511, 3512, 3513, 3514, 3515, 3516, 3517, 3518, 3519, 3520, 3521, 3522, 3523, 3524, 3525, 3526, 3527, 3528, 3529, 3530, 3531, 3532, 3533, 3534, 3535, 3536, 3537, 3538, 3539, 3540, 3541, 3542, 3543, 3544, 3545, 3546, 3547, 3548, 3549, 3550, 3551, 3552, 3553, 3554, 3555, 3556, 3557, 3558, 3559, 3560, 3561, 3562, 3563, 3564, 3565, 3566, 3567, 3568, 3569, 3570, 3571, 3572, 3573, 3574, 3575, 3585, 3586, 3589, 3590, 3591, 3592, 3597, 3602, 3603, 3606, 3607, 3608, 3609, 3610, 3612, 3613, 3615, 3616, 3617, 3618, 3619, 3620, 3621, 3622, 3627, 3631, 3632, 3633, 3638, 3639, 3640, 3641, 3642, 3645, 3647, 3648, 3651, 3657, 3661, 3668, 3669, 3674, 3675, 3682, 3683, 3684, 3686, 3687, 3688, 3689, 3690, 3692, 3694, 3696, 3699, 3700, 3702, 3704, 3705, 3706, 3708, 3710, 3711, 3712, 3718, 3719, 3720, 3721, 3723, 3729, 3731, 3732, 3733, 3735, 3736, 3741, 3743, 3744, 3746, 3747, 3748, 3749, 3750, 3751, 3755, 3758, 3759, 3763, 3764, 3765, 3766, 3767, 3768, 3770, 3773, 3774, 3775, 3776, 3777, 3778, 3779, 3780, 3781, 3782, 3783, 3784, 3785, 3786, 3787, 3788, 3789, 3790, 3791, 3792, 3793, 3794, 3795, 3796, 3797, 3798, 3799, 3800, 3801, 3802, 3803, 3804, 3805, 3806, 3930, 4477, 4478, 4479, 4480, 4481, 4482, 4483, 4484, 4485, 4486, 4487, 4488, 4489, 4490, 4491, 4492, 4493, 4494, 4495, 4496, 4497, 4498, 4499, 4500, 4501, 4502, 4503, 4504, 4505, 4506, 4507, 4508, 4509, 4510, 4511, 4512, 4513, 4514, 4515, 4516, 4517, 4518, 4519, 4520, 4521, 4522, 4523, 4524, 4525, 4526, 4527, 4528, 4529, 4530, 4531, 4532, 4533, 4534, 4535, 4536, 4537, 4538, 4539, 4540, 4541, 4542, 4543, 4544, 4545, 4546, 4547, 4548, 4549, 4550, 4551, 4552, 4553, 4554, 4555, 4556, 4557, 4558, 4559, 4560, 4561, 4562, 4563, 4564, 4565, 4566, 4567, 4568, 4569, 4570, 4571, 4572, 4573, 4574, 4575, 4576, 4577, 4578, 4579, 4580, 4581, 4582, 4583, 4584, 4585, 4586, 4587, 4588, 4589, 4590, 4591, 4592, 4593, 4594, 4595, 4596, 4597, 4598, 4599, 4600, 4601, 4602, 4603, 4604, 4605, 4606, 4607, 4608, 4609, 4610, 4611, 4612, 4613, 4614, 4615, 4616, 4617, 4618, 4619, 4620, 4621, 4622, 4623, 4624, 4625, 4626, 4627, 4628, 4629, 4630, 4631, 4632, 4633, 4634, 4635, 4636, 4637, 4638, 4639, 4640, 4641, 4642, 4643, 4644, 4645, 4646, 4647, 4648, 4649, 4650, 4651, 4652, 4653, 4654, 4655, 4656, 4657, 4658, 4659, 4660, 4661, 4662, 4663, 4664, 4665, 4666, 4667, 4668, 4669, 4670, 4671, 4672, 4673, 4674, 4675, 4676, 4677, 4678, 4679, 4680, 4681, 4682, 4683, 4684, 4685, 4686, 4687, 4688, 4689, 4690, 4691, 4692, 4693, 4694, 4695, 4696, 4697, 4698, 4699, 4700, 4701, 4702, 4703, 4704, 4705, 4706, 4707, 4708, 4709, 4710, 4711, 4712, 4713, 4714, 4715, 4716, 4717, 4718, 4719, 4720, 4721, 4722, 4723, 4724, 4725, 4726, 4727, 4728, 4729, 4730, 4731, 4732, 4733, 4734, 4735, 4736, 4737, 4738, 4739, 4740, 4741, 4742, 4743, 4744, 4745, 4746, 4747, 4748, 4749, 4750, 4751, 4752, 4753, 4754, 4755, 4756, 4757, 4758, 4759, 4760, 4761, 4762, 4763, 4764, 4765, 4766, 4767, 4768, 4769, 4770, 4771, 4772, 4773, 4774, 4775, 4776, 4777, 4778, 4779, 4780, 4781, 4782, 4783, 4784, 4785, 4786, 4787, 4788, 4789, 4790, 4791, 4792, 4793, 4794, 4795, 4796, 4797, 4798, 4799, 4800, 4801, 4802, 4803, 4804, 4805, 4806, 4807, 4808, 4809, 4810, 4811, 4812, 4813, 4814, 4815, 4816, 4817, 4818, 4819, 4820, 4821, 4822, 4823, 4824, 4825, 4826, 4827, 4828, 4829, 4830, 4831, 4832, 4833, 4834, 4835, 4836, 4837, 4838, 4839, 4840, 4841, 4842, 4843, 4844, 4845, 4846, 4847, 4848, 4849, 4850, 4851, 4852, 4853, 4854, 4855, 4856, 4857, 4858, 4859, 4860, 4861, 4862, 4863, 4864, 4865, 4866, 4867, 4868, 4869, 4870, 4871, 4872, 4873, 4874, 4875, 4876, 4877, 4878, 4879, 4880, 4881, 4882, 4883, 4884, 4885, 4886, 4887, 4888, 4889, 4890, 4891, 4892, 4893, 4894, 4895, 4896, 4897, 4898, 4899, 4900, 4901, 4902, 4903, 4904, 4905, 4906, 4907, 4908, 4909, 4910, 4911, 4912, 4913, 4914, 4915, 4916, 4917, 4918, 4919, 4920, 4921, 4922, 4923, 4924, 4925, 4926, 4927, 4928, 4929, 4930, 4931, 4932, 4933, 4934, 4935, 4936, 4937, 4938, 4939, 4940, 4941, 4942, 4943, 4944, 4945, 4946, 4947, 4948, 4949, 4950, 4951, 4952, 4953, 4954, 4955, 4956, 4957, 4958, 4959, 4960, 4961, 4962, 4963, 4964, 4965, 4966, 4967, 4968, 4969, 4970, 4971, 4972, 4973, 4974, 4975, 4976, 4977, 4978, 4979, 4980, 4981, 4982, 4983, 4984, 4985, 4986, 4987, 4988, 4989, 4990, 4991, 4992, 4993, 4994, 4995, 4996, 4997, 4998, 4999, 5000, 5001, 5002, 5003, 5004, 5005, 5006, 5007, 5008, 5009, 5010, 5011, 5012, 5013, 5014, 5015, 5016, 5017, 5018, 5019, 5020, 5021, 5022
+ ])
+ )
+
+ # remove the intersection with neck from scalp and get the region for hair
+ face_and_neck = torch.cat([self.v.face, self.v.neck]).unique()
+ # get the intersection between scalp and face_and_neck
+ uniques, counts = torch.cat([self.v.scalp, face_and_neck]).unique(return_counts=True)
+ intersection = uniques[counts == 2]
+ uniques, counts = torch.cat([self.v.scalp, intersection]).unique(return_counts=True)
+ hair = uniques[counts == 1]
+ self.v.register_buffer("hair", hair)
+
+ # unions
+ self.v.register_buffer("ears", torch.cat([self.v.right_ear, self.v.left_ear]))
+ self.v.register_buffer("eyeballs", torch.cat([self.v.right_eyeball, self.v.left_eyeball]))
+ self.v.register_buffer("irises", torch.cat([self.v.right_iris, self.v.left_iris]))
+ self.v.register_buffer("left_eye", torch.cat([self.v.left_eye_region, self.v.left_eyeball]))
+ self.v.register_buffer("right_eye", torch.cat([self.v.right_eye_region, self.v.right_eyeball]))
+ self.v.register_buffer("eyelids", torch.cat([self.v.left_eyelid, self.v.right_eyelid]))
+ self.v.register_buffer("lip_inside_ring", torch.cat([self.v.lip_inside_ring_upper, self.v.lip_inside_ring_lower, torch.tensor([1594, 2730])]))
+
+ # remove the intersection with irises from eyeballs and get the region for scleras
+ uniques, counts = torch.cat([self.v.eyeballs, self.v.irises]).unique(return_counts=True)
+ intersection = uniques[counts == 2]
+ uniques, counts = torch.cat([self.v.eyeballs, intersection]).unique(return_counts=True)
+ sclerae = uniques[counts == 1]
+ self.v.register_buffer("sclerae", sclerae)
+
+ # skin
+ skin_except = ["eyeballs", "hair", "lips_tight", "boundary"]
+ if self.num_verts == 5083:
+ skin_except.append("teeth")
+ skin = self.get_vid_except_region(skin_except)
+ self.v.register_buffer("skin", skin)
+
+ def construct_vid_table(self):
+ self.vid_to_region = defaultdict(list) # vertex id -> region name
+ for region_name, v_mask in self.v:
+ for v_id in v_mask:
+ self.vid_to_region[v_id.item()].append(region_name)
+
+ def process_face_mask(self, faces):
+ logger.info("Processing face masks for FLAME...")
+
+ face_masks = defaultdict(list) # region name -> face id
+ for f_id, f in enumerate(faces):
+ counters = defaultdict(int)
+ for v_id in f:
+ for region_name in self.vid_to_region[v_id.item()]:
+ counters[region_name] += 1
+
+ for region_name, count in counters.items():
+ if count >= 3: # create straight boundaries, with seams
+ # if count > 1: # create zigzag boundaries, no seams
+ face_masks[region_name].append(f_id)
+
+ self.f = BufferContainer()
+ for region_name, f_mask in face_masks.items():
+ self.f.register_buffer(region_name, torch.tensor(f_mask, dtype=torch.long))
+
+ def process_face_clusters(self, face_clusters):
+ """ Construct a lookup table from face id to cluster id.
+
+ cluster #0: background
+ cluster #1: foreground
+ cluster #2: faces in face_clusters[0]
+ cluster #3: faces in face_clusters[1]
+ ...
+ """
+ logger.info("Processing face clusters...")
+
+ fid2cid = torch.ones(self.num_faces+1, dtype=torch.long) # faces are always treated as foreground
+ for cid, cluster in enumerate(face_clusters):
+ try:
+ fids = self.get_fid_by_region([cluster])
+ except Exception as e:
+ logger.warning(f"Ignoring unknown cluster {cluster}.")
+ continue
+ fid2cid[fids] = cid + 2 # reserve cluster #0 for the background and #1 for faces that do not belong to any cluster
+ self.register_buffer("fid2cid", fid2cid)
+
+ def process_vt_mask(self, faces, faces_t):
+ logger.info("Processing vt masks for FLAME...")
+
+ vt_masks = defaultdict(list) # region name -> vt id
+ for f_id, (face, face_t) in enumerate(zip(faces, faces_t)):
+ for v_id, vt_id in zip(face, face_t):
+ for region_name in self.vid_to_region[v_id.item()]:
+ vt_masks[region_name].append(vt_id.item())
+
+ self.vt = BufferContainer()
+ for region_name, vt_mask in vt_masks.items():
+ self.vt.register_buffer(region_name, torch.tensor(vt_mask, dtype=torch.long))
+
+ def get_vid_by_region(self, regions, keep_order=False):
+ """Get vertex indicies by regions"""
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ vid = torch.cat([self.v.get_buffer(k) for k in regions])
+ if keep_order:
+ return vid
+ else:
+ return vid.unique()
+ else:
+ return torch.tensor([], dtype=torch.long)
+
+ def get_vid_except_region(self, regions):
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ indices = torch.cat([self.v.get_buffer(k) for k in regions]).unique()
+ else:
+ indices = torch.tensor([], dtype=torch.long)
+
+ # get the vertex indicies that are not included by regions
+ vert_idx = torch.arange(0, self.num_verts, device=indices.device)
+ combined = torch.cat((indices, vert_idx))
+ uniques, counts = combined.unique(return_counts=True)
+ return uniques[counts == 1]
+
+ def get_fid_by_region(self, regions):
+ """Get face indicies by regions"""
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ return torch.cat([self.f.get_buffer(k) for k in regions]).unique()
+ else:
+ return torch.tensor([], dtype=torch.long)
+
+ def get_fid_except_region(self, regions):
+ if isinstance(regions, str):
+ regions = [regions]
+ if len(regions) > 0:
+ indices = torch.cat([self.f.get_buffer(k) for k in regions]).unique()
+ else:
+ indices = torch.tensor([], dtype=torch.long)
+
+ # get the face indicies that are not included by regions
+ face_idx = torch.arange(0, self.num_faces, device=indices.device)
+ combined = torch.cat((indices, face_idx))
+ uniques, counts = combined.unique(return_counts=True)
+ return uniques[counts == 1]
+
+ def get_fid_except_fids(self, fids):
+ # get the face indicies that are not included
+ face_idx = torch.arange(0, self.num_faces, device=fids.device)
+ combined = torch.cat((fids, face_idx))
+ uniques, counts = combined.unique(return_counts=True)
+ return uniques[counts == 1]
+
+
+class FlameUvMask(BufferContainer):
+ def __init__(self, uv_mask_path=FLAME_UVMASK_PATH):
+ super().__init__()
+ logger.info("Processing uv masks for FLAME...")
+
+ uv_masks = np.load(uv_mask_path, allow_pickle=True, encoding="latin1")
+ for region_name, uv_mask in uv_masks.items():
+ self.register_buffer(region_name, torch.tensor(uv_mask, dtype=torch.bool))
+
+ def get_uvmask_by_region(self, regions):
+ """Get uv masks by regions"""
+ if isinstance(regions, str):
+ regions = [regions]
+ return torch.cat([self.get_buffer(k)[..., None] for k in regions], dim=-1).max(dim=-1)[0]
diff --git a/vhap/model/lbs.py b/vhap/model/lbs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c377f6484a67e75b36fb35ccf28200ab122d5af
--- /dev/null
+++ b/vhap/model/lbs.py
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+
+
+def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
+ """Calculates the rotation matrices for a batch of rotation vectors
+ Parameters
+ ----------
+ rot_vecs: torch.tensor Nx3
+ array of N axis-angle vectors
+ Returns
+ -------
+ R: torch.tensor Nx3x3
+ The rotation matrices for the given axis-angle parameters
+ """
+
+ batch_size = rot_vecs.shape[0]
+ device = rot_vecs.device
+
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
+ rot_dir = rot_vecs / angle
+
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
+
+ # Bx1 arrays
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
+
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
+ (batch_size, 3, 3)
+ )
+
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
+ return rot_mat
+
+
+def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
+ """Calculates landmarks by barycentric interpolation
+
+ Parameters
+ ----------
+ vertices: torch.tensor BxVx3, dtype = torch.float32
+ The tensor of input vertices
+ faces: torch.tensor Fx3, dtype = torch.long
+ The faces of the mesh
+ lmk_faces_idx: torch.tensor L, dtype = torch.long
+ The tensor with the indices of the faces used to calculate the
+ landmarks.
+ lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
+ The tensor of barycentric coordinates that are used to interpolate
+ the landmarks
+
+ Returns
+ -------
+ landmarks: torch.tensor BxLx3, dtype = torch.float32
+ The coordinates of the landmarks for each mesh in the batch
+ """
+ # Extract the indices of the vertices for each face
+ # BxLx3
+ batch_size, num_verts = vertices.shape[:2]
+ device = vertices.device
+
+ lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
+ batch_size, -1, 3
+ )
+
+ lmk_faces += (
+ torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1)
+ * num_verts
+ )
+
+ lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
+
+ landmarks = torch.einsum("blfi,blf->bli", [lmk_vertices, lmk_bary_coords])
+ return landmarks
+
+
+def lbs(
+ pose,
+ v_shaped,
+ posedirs,
+ J_regressor,
+ parents,
+ lbs_weights,
+ pose2rot=True,
+ dtype=torch.float32,
+):
+ """Performs Linear Blend Skinning with the given shape and pose parameters
+
+ Parameters
+ ----------
+ betas : torch.tensor BxNB
+ The tensor of shape parameters
+ pose : torch.tensor Bx(J + 1) * 3
+ The pose parameters in axis-angle format
+ v_template: torch.tensor BxVx3
+ The template mesh that will be deformed
+ shapedirs : torch.tensor 1xNB
+ The tensor of PCA shape displacements
+ posedirs : torch.tensor Px(V * 3)
+ The pose PCA coefficients
+ J_regressor : torch.tensor JxV
+ The regressor array that is used to calculate the joints from
+ the position of the vertices
+ parents: torch.tensor J
+ The array that describes the kinematic tree for the model
+ lbs_weights: torch.tensor N x V x (J + 1)
+ The linear blend skinning weights that represent how much the
+ rotation matrix of each part affects each vertex
+ pose2rot: bool, optional
+ Flag on whether to convert the input pose tensor to rotation
+ matrices. The default value is True. If False, then the pose tensor
+ should already contain rotation matrices and have a size of
+ Bx(J + 1)x9
+ dtype: torch.dtype, optional
+
+ Returns
+ -------
+ verts: torch.tensor BxVx3
+ The vertices of the mesh after applying the shape and pose
+ displacements.
+ joints: torch.tensor BxJx3
+ The joints of the model
+ """
+
+ batch_size = pose.shape[0]
+ device = pose.device
+
+ # Get the joints
+ # NxJx3 array
+ J = vertices2joints(J_regressor, v_shaped)
+
+ # 3. Add pose blend shapes
+ # N x J x 3 x 3
+ ident = torch.eye(3, dtype=dtype, device=device)
+ if pose2rot:
+ rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view(
+ [batch_size, -1, 3, 3]
+ )
+
+ pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
+ # (N x P) x (P, V * 3) -> N x V x 3
+ pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
+ else:
+ pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
+ rot_mats = pose.view(batch_size, -1, 3, 3)
+
+ pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(
+ batch_size, -1, 3
+ )
+
+ v_posed = pose_offsets + v_shaped
+
+ # 4. Get the global joint location
+ J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
+
+ # 5. Do skinning:
+ # W is N x V x (J + 1)
+ W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
+ # (N x V x (J + 1)) x (N x (J + 1) x 16)
+ num_joints = J_regressor.shape[0]
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
+
+ homogen_coord = torch.ones(
+ [batch_size, v_posed.shape[1], 1], dtype=dtype, device=device
+ )
+ v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
+ v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
+
+ verts = v_homo[:, :, :3, 0]
+
+ return verts, J_transformed, A[:, 1]
+
+
+def vertices2joints(J_regressor, vertices):
+ """Calculates the 3D joint locations from the vertices
+
+ Parameters
+ ----------
+ J_regressor : torch.tensor JxV
+ The regressor array that is used to calculate the joints from the
+ position of the vertices
+ vertices : torch.tensor BxVx3
+ The tensor of mesh vertices
+
+ Returns
+ -------
+ torch.tensor BxJx3
+ The location of the joints
+ """
+
+ return torch.einsum("bik,ji->bjk", [vertices, J_regressor])
+
+
+def blend_shapes(betas, shape_disps):
+ """Calculates the per vertex displacement due to the blend shapes
+
+
+ Parameters
+ ----------
+ betas : torch.tensor Bx(num_betas)
+ Blend shape coefficients
+ shape_disps: torch.tensor Vx3x(num_betas)
+ Blend shapes
+
+ Returns
+ -------
+ torch.tensor BxVx3
+ The per-vertex displacement due to shape deformation
+ """
+
+ # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
+ # i.e. Multiply each shape displacement by its corresponding beta and
+ # then sum them.
+ blend_shape = torch.einsum("bl,mkl->bmk", [betas, shape_disps])
+ return blend_shape
+
+
+def transform_mat(R, t):
+ """Creates a batch of transformation matrices
+ Args:
+ - R: Bx3x3 array of a batch of rotation matrices
+ - t: Bx3x1 array of a batch of translation vectors
+ Returns:
+ - T: Bx4x4 Transformation matrix
+ """
+ # No padding left or right, only add an extra row
+ return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
+
+
+def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
+ """
+ Applies a batch of rigid transformations to the joints
+
+ Parameters
+ ----------
+ rot_mats : torch.tensor BxNx3x3
+ Tensor of rotation matrices
+ joints : torch.tensor BxNx3
+ Locations of joints
+ parents : torch.tensor BxN
+ The kinematic tree of each object
+ dtype : torch.dtype, optional:
+ The data type of the created tensors, the default is torch.float32
+
+ Returns
+ -------
+ posed_joints : torch.tensor BxNx3
+ The locations of the joints after applying the pose rotations
+ rel_transforms : torch.tensor BxNx4x4
+ The relative (with respect to the root joint) rigid transformations
+ for all the joints
+ """
+
+ joints = torch.unsqueeze(joints, dim=-1)
+
+ rel_joints = joints.clone().contiguous()
+ rel_joints[:, 1:] = rel_joints[:, 1:] - joints[:, parents[1:]]
+
+ transforms_mat = transform_mat(rot_mats.view(-1, 3, 3), rel_joints.view(-1, 3, 1))
+ transforms_mat = transforms_mat.view(-1, joints.shape[1], 4, 4)
+
+ transform_chain = [transforms_mat[:, 0]]
+ for i in range(1, parents.shape[0]):
+ # Subtract the joint location at the rest pose
+ # No need for rotation, since it's identity when at rest
+ curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i])
+ transform_chain.append(curr_res)
+
+ transforms = torch.stack(transform_chain, dim=1)
+
+ # The last column of the transformations contains the posed joints
+ posed_joints = transforms[:, :, :3, 3]
+
+ joints_homogen = F.pad(joints, [0, 0, 0, 1])
+
+ rel_transforms = transforms - F.pad(
+ torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]
+ )
+
+ return posed_joints, rel_transforms
diff --git a/vhap/model/tracker.py b/vhap/model/tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..57e88754fdb31af9faa1514fd0241d5b3180b877
--- /dev/null
+++ b/vhap/model/tracker.py
@@ -0,0 +1,1570 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+from vhap.config.base import import_module, PhotometricStageConfig, BaseTrackingConfig
+from vhap.model.flame import FlameHead, FlameTexPCA, FlameTexPainted, FlameUvMask
+from vhap.model.lbs import batch_rodrigues
+from vhap.util.mesh import (
+ get_mtl_content,
+ get_obj_content,
+ normalize_image_points,
+)
+from vhap.util.log import get_logger
+from vhap.util.visualization import plot_landmarks_2d
+
+from torch.utils.tensorboard import SummaryWriter
+import torch
+import torchvision
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import numpy as np
+from matplotlib import cm
+from typing import Literal
+from functools import partial
+import tyro
+import yaml
+from datetime import datetime
+import threading
+from typing import Optional
+from collections import defaultdict
+from copy import deepcopy
+import time
+import os
+
+
+class FlameTracker:
+ def __init__(self, cfg: BaseTrackingConfig):
+ self.cfg = cfg
+
+ self.device = cfg.device
+ self.tb_writer = None
+
+ # model
+ self.flame = FlameHead(
+ cfg.model.n_shape,
+ cfg.model.n_expr,
+ add_teeth=cfg.model.add_teeth,
+ remove_lip_inside=cfg.model.remove_lip_inside,
+ face_clusters=cfg.model.tex_clusters,
+ ).to(self.device)
+
+ if cfg.model.tex_painted:
+ self.flame_tex_painted = FlameTexPainted(tex_size=cfg.model.tex_resolution).to(self.device)
+ else:
+ self.flame_tex_pca = FlameTexPCA(cfg.model.n_tex, tex_size=cfg.model.tex_resolution).to(self.device)
+
+ self.flame_uvmask = FlameUvMask().to(self.device)
+
+ # renderer for visualization, dense photometric energy
+ if self.cfg.render.backend == 'nvdiffrast':
+ from vhap.util.render_nvdiffrast import NVDiffRenderer
+
+ self.render = NVDiffRenderer(
+ use_opengl=self.cfg.render.use_opengl,
+ lighting_type=self.cfg.render.lighting_type,
+ lighting_space=self.cfg.render.lighting_space,
+ disturb_rate_fg=self.cfg.render.disturb_rate_fg,
+ disturb_rate_bg=self.cfg.render.disturb_rate_bg,
+ fid2cid=self.flame.mask.fid2cid,
+ )
+ elif self.cfg.render.backend == 'pytorch3d':
+ from vhap.util.render_pytorch3d import PyTorch3DRenderer
+
+ self.render = PyTorch3DRenderer()
+ else:
+ raise NotImplementedError(f"Unknown renderer backend: {self.cfg.render.backend}")
+
+ def load_from_tracked_flame_params(self, fp):
+ """
+ loads checkpoint from tracked_flame_params file. Counterpart to save_result()
+ :param fp:
+ :return:
+ """
+ report = np.load(fp)
+
+ # LOADING PARAMETERS
+ def load_param(param, ckpt_array):
+ param.data[:] = torch.from_numpy(ckpt_array).to(param.device)
+
+ def load_param_list(param_list, ckpt_array):
+ for i in range(min(len(param_list), len(ckpt_array))):
+ load_param(param_list[i], ckpt_array[i])
+
+ load_param_list(self.rotation, report["rotation"])
+ load_param_list(self.translation, report["translation"])
+ load_param_list(self.neck_pose, report["neck_pose"])
+ load_param_list(self.jaw_pose, report["jaw_pose"])
+ load_param_list(self.eyes_pose, report["eyes_pose"])
+ load_param(self.shape, report["shape"])
+ load_param_list(self.expr, report["expr"])
+ load_param(self.lights, report["lights"])
+ # self.frame_idx = report["n_processed_frames"]
+ if not self.calibrated:
+ load_param(self.focal_length, report["focal_length"])
+
+ if not self.cfg.model.tex_painted:
+ if "tex" in report:
+ load_param(self.tex_pca, report["tex"])
+ else:
+ self.logger.warn("No tex_extra found in flame_params!")
+
+ if self.cfg.model.tex_extra:
+ if "tex_extra" in report:
+ load_param(self.tex_extra, report["tex_extra"])
+ else:
+ self.logger.warn("No tex_extra found in flame_params!")
+
+ if self.cfg.model.use_static_offset:
+ if "static_offset" in report:
+ load_param(self.static_offset, report["static_offset"])
+ else:
+ self.logger.warn("No static_offset found in flame_params!")
+
+ if self.cfg.model.use_dynamic_offset:
+ if "dynamic_offset" in report:
+ load_param_list(self.dynamic_offset, report["dynamic_offset"])
+ else:
+ self.logger.warn("No dynamic_offset found in flame_params!")
+
+ def trimmed_decays(self, is_init):
+ decays = {}
+ for k, v in self.decays.items():
+ if is_init and "init" in k or not is_init and "init" not in k:
+ decays[k.replace("_init", "")] = v
+ return decays
+
+ def clear_cache(self):
+ self.render.clear_cache()
+
+ def get_current_frame(self, frame_idx, include_keyframes=False):
+ """
+ Creates a single item batch from the frame data at index frame_idx in the dataset.
+ If include_keyframes option is set, keyframe data will be appended to the batch. However,
+ it is guaranteed that the frame data belonging to frame_idx is at position 0
+ :param frame_idx:
+ :return:
+ """
+ indices = [frame_idx]
+ if include_keyframes:
+ indices += self.cfg.exp.keyframes
+
+ samples = []
+ for idx in indices:
+ sample = self.dataset.getitem_by_timestep(idx)
+ # sample["timestep_index"] = idx
+
+ # for k, v in sample.items():
+ # if isinstance(v, torch.Tensor):
+ # sample[k] = v[None, ...].to(self.device)
+
+ samples.append(sample)
+
+ # if also keyframes have been loaded, stack all data
+ sample = {}
+ for k, v in samples[0].items():
+ values = [s[k] for s in samples]
+ if isinstance(v, torch.Tensor):
+ values = torch.cat(values, dim=0)
+ sample[k] = values
+
+ if "lmk2d_iris" in sample:
+ sample["lmk2d"] = torch.cat([sample["lmk2d"], sample["lmk2d_iris"]], dim=1)
+ return sample
+
+ def fill_cam_params_into_sample(self, sample):
+ """
+ Adds intrinsics and extrinics to sample, if data is not calibrated
+ """
+ if self.calibrated:
+ assert "intrinsic" in sample
+ assert "extrinsic" in sample
+ else:
+ b, _, h, w = sample["rgb"].shape
+ # K = torch.eye(3, 3).to(self.device)
+
+ # denormalize cam params
+ f = self.focal_length * max(h, w)
+ cx, cy = torch.tensor([[0.5*w], [0.5*h]]).to(f)
+
+ sample["intrinsic"] = torch.stack([f, f, cx, cy], dim=1)
+ sample["extrinsic"] = self.RT[None, ...].expand(b, -1, -1)
+
+ def configure_optimizer(self, params, lr_scale=1.0):
+ """
+ Creates optimizer for the given set of parameters
+ :param params:
+ :return:
+ """
+ # copy dict because we will call 'pop'
+ params = params.copy()
+ param_groups = []
+ default_lr = self.cfg.lr.base
+
+ # dict map group name to param dict keys
+ group_def = {
+ "translation": ["translation"],
+ "expr": ["expr"],
+ "light": ["lights"],
+ }
+ if not self.calibrated:
+ group_def ["cam"] = ["cam"]
+ if self.cfg.model.use_static_offset:
+ group_def ["static_offset"] = ["static_offset"]
+ if self.cfg.model.use_dynamic_offset:
+ group_def ["dynamic_offset"] = ["dynamic_offset"]
+
+ # dict map group name to lr
+ group_lr = {
+ "translation": self.cfg.lr.translation,
+ "expr": self.cfg.lr.expr,
+ "light": self.cfg.lr.light,
+ }
+ if not self.calibrated:
+ group_lr["cam"] = self.cfg.lr.camera
+ if self.cfg.model.use_static_offset:
+ group_lr["static_offset"] = self.cfg.lr.static_offset
+ if self.cfg.model.use_dynamic_offset:
+ group_lr["dynamic_offset"] = self.cfg.lr.dynamic_offset
+
+ for group_name, param_keys in group_def.items():
+ selected = []
+ for p in param_keys:
+ if p in params:
+ selected += params.pop(p)
+ if len(selected) > 0:
+ param_groups.append({"params": selected, "lr": group_lr[group_name] * lr_scale})
+
+ # create default group with remaining params
+ selected = []
+ for _, v in params.items():
+ selected += v
+ param_groups.append({"params": selected})
+
+ optim = torch.optim.Adam(param_groups, lr=default_lr * lr_scale)
+ return optim
+
+ def initialize_frame(self, frame_idx):
+ """
+ Initializes parameters of frame frame_idx
+ :param frame_idx:
+ :return:
+ """
+ if frame_idx > 0:
+ self.initialize_from_previous(frame_idx)
+
+ def initialize_from_previous(self, frame_idx):
+ """
+ Initializes the flame parameters with the optimized ones from the previous frame
+ :param frame_idx:
+ :return:
+ """
+ if frame_idx == 0:
+ return
+
+ param_list = [
+ self.expr,
+ self.neck_pose,
+ self.jaw_pose,
+ self.translation,
+ self.rotation,
+ self.eyes_pose,
+ ]
+
+ for param in param_list:
+ param[frame_idx].data = param[frame_idx - 1].detach().clone().data
+
+ def select_frame_indices(self, frame_idx, include_keyframes):
+ indices = [frame_idx]
+ if include_keyframes:
+ indices += self.cfg.exp.keyframes
+ return indices
+
+ def forward_flame(self, frame_idx, include_keyframes):
+ """
+ Evaluates the flame model using the given parameters
+ :param flame_params:
+ :return:
+ """
+ indices = self.select_frame_indices(frame_idx, include_keyframes)
+
+ dynamic_offset = self.to_batch(self.dynamic_offset, indices) if self.cfg.model.use_dynamic_offset else None
+
+ ret = self.flame(
+ self.shape[None, ...].expand(len(indices), -1),
+ self.to_batch(self.expr, indices),
+ self.to_batch(self.rotation, indices),
+ self.to_batch(self.neck_pose, indices),
+ self.to_batch(self.jaw_pose, indices),
+ self.to_batch(self.eyes_pose, indices),
+ self.to_batch(self.translation, indices),
+ return_verts_cano=True,
+ static_offset=self.static_offset,
+ dynamic_offset=dynamic_offset,
+ )
+ verts, verts_cano, lmks = ret[0], ret[1], ret[2]
+ albedos = self.get_albedo().expand(len(indices), -1, -1, -1)
+ return verts, verts_cano, lmks, albedos
+
+ def get_base_texture(self):
+ if self.cfg.model.tex_extra and not self.cfg.model.residual_tex:
+ albedos_base = self.tex_extra[None, ...]
+ else:
+ if self.cfg.model.tex_painted:
+ albedos_base = self.flame_tex_painted()
+ else:
+ albedos_base = self.flame_tex_pca(self.tex_pca[None, :])
+ return albedos_base
+
+ def get_albedo(self):
+ albedos_base = self.get_base_texture()
+
+ if self.cfg.model.tex_extra and self.cfg.model.residual_tex:
+ albedos_res = self.tex_extra[None, :]
+ if albedos_base.shape[-1] != albedos_res.shape[-1] or albedos_base.shape[-2] != albedos_res.shape[-2]:
+ albedos_base = F.interpolate(albedos_base, albedos_res.shape[-2:], mode='bilinear')
+ albedos = albedos_base + albedos_res
+ else:
+ albedos = albedos_base
+
+ return albedos
+
+ def rasterize_flame(
+ self, sample, verts, faces, camera_index=None, train_mode=False
+ ):
+ """
+ Rasterizes the flame head mesh
+ :param verts:
+ :param albedos:
+ :param K:
+ :param RT:
+ :param resolution:
+ :param use_cache:
+ :return:
+ """
+ # cameras parameters
+ K = sample["intrinsic"].clone().to(self.device)
+ RT = sample["extrinsic"].to(self.device)
+ if camera_index is not None:
+ K = K[[camera_index]]
+ RT = RT[[camera_index]]
+
+ H, W = self.image_size
+ image_size = H, W
+
+ # rasterize fragments
+ rast_dict = self.render.rasterize(verts, faces, RT, K, image_size, False, train_mode)
+ return rast_dict
+
+ @torch.no_grad()
+ def get_background_color(self, gt_rgb, gt_alpha, stage):
+ if stage is None: # when stage is None, it means we are in the evaluation mode
+ background = self.cfg.render.background_eval
+ else:
+ background = self.cfg.render.background_train
+
+ if background == 'target':
+ """use gt_rgb as background"""
+ color = gt_rgb.permute(0, 2, 3, 1)
+ elif background == 'white':
+ color = [1, 1, 1]
+ elif background == 'black':
+ color = [0, 0, 0]
+ else:
+ raise NotImplementedError(f"Unknown background mode: {background}")
+ return color
+
+ def render_rgba(
+ self, rast_dict, verts, faces, albedos, lights, background_color=[1, 1, 1],
+ align_texture_except_fid=None, align_boundary_except_vid=None, enable_disturbance=False,
+ ):
+ """
+ Renders the rgba image from the rasterization result and
+ the optimized texture + lights
+ """
+ faces_uv = self.flame.textures_idx
+ if self.cfg.render.backend == 'nvdiffrast':
+ verts_uv = self.flame.verts_uvs.clone()
+ verts_uv[:, 1] = 1 - verts_uv[:, 1]
+ tex = albedos
+
+ render_out = self.render.render_rgba(
+ rast_dict, verts, faces, verts_uv, faces_uv, tex, lights, background_color,
+ align_texture_except_fid, align_boundary_except_vid, enable_disturbance
+ )
+ render_out = {k: v.permute(0, 3, 1, 2) for k, v in render_out.items()}
+ elif self.cfg.render.backend == 'pytorch3d':
+ B = verts.shape[0] # TODO: double check
+ verts_uv = self.flame.face_uvcoords.repeat(B, 1, 1)
+ tex = albedos.expand(B, -1, -1, -1)
+
+ rgba = self.render.render_rgba(
+ rast_dict, verts, faces, verts_uv, faces_uv, tex, lights, background_color
+ )
+ render_out = {'rgba': rgba.permute(0, 3, 1, 2)}
+ else:
+ raise NotImplementedError(f"Unknown renderer backend: {self.cfg.render.backend}")
+
+ return render_out
+
+ def render_normal(self, rast_dict, verts, faces):
+ """
+ Renders the rgba image from the rasterization result and
+ the optimized texture + lights
+ """
+ uv_coords = self.flame.face_uvcoords
+ uv_coords = uv_coords.repeat(verts.shape[0], 1, 1)
+ return self.render.render_normal(rast_dict, verts, faces, uv_coords)
+
+ def compute_lmk_energy(self, sample, pred_lmks, disable_jawline_landmarks=False):
+ """
+ Computes the landmark energy loss term between groundtruth landmarks and flame landmarks
+ :param sample:
+ :param pred_lmks:
+ :return: the lmk loss for all 68 facial landmarks, a separate 2 pupil landmark loss and
+ a relative eye close term
+ """
+ img_size = sample["rgb"].shape[-2:]
+
+ # ground-truth landmark
+ lmk2d = sample["lmk2d"].clone().to(pred_lmks)
+ lmk2d, confidence = lmk2d[:, :, :2], lmk2d[:, :, 2]
+ lmk2d[:, :, 0], lmk2d[:, :, 1] = normalize_image_points(
+ lmk2d[:, :, 0], lmk2d[:, :, 1], img_size
+ )
+
+ # predicted landmark
+ K = sample["intrinsic"].to(self.device)
+ RT = sample["extrinsic"].to(self.device)
+ pred_lmk_ndc = self.render.world_to_ndc(pred_lmks, RT, K, img_size, flip_y=True)
+ pred_lmk2d = pred_lmk_ndc[:, :, :2]
+
+ if (lmk2d.shape[1] == 70):
+ diff = lmk2d - pred_lmk2d
+ confidence = confidence[:, :70]
+ # eyes weighting
+ confidence[:, 68:] = confidence[:, 68:] * 2
+ else:
+ diff = lmk2d[:, :68] - pred_lmk2d[:, :68]
+ confidence = confidence[:, :68]
+
+ # compute general landmark term
+ lmk_loss = torch.norm(diff, dim=2, p=1) * confidence
+
+ result_dict = {
+ "gt_lmk2d": lmk2d,
+ "pred_lmk2d": pred_lmk2d,
+ }
+
+ return lmk_loss.mean(), result_dict
+
+ def compute_photometric_energy(
+ self,
+ sample,
+ verts,
+ faces,
+ albedos,
+ rast_dict,
+ step_i=None,
+ stage=None,
+ include_keyframes=False,
+ ):
+ """
+ Computes the dense photometric energy
+ :param sample:
+ :param vertices:
+ :param albedos:
+ :return:
+ """
+ gt_rgb = sample["rgb"].to(verts)
+ if "alpha" in sample:
+ gt_alpha = sample["alpha_map"].to(verts)
+ else:
+ gt_alpha = None
+
+ lights = self.lights[None] if self.lights is not None else None
+ bg_color = self.get_background_color(gt_rgb, gt_alpha, stage)
+
+ align_texture_except_fid = self.flame.mask.get_fid_by_region(
+ self.cfg.pipeline[stage].align_texture_except
+ ) if stage is not None else None
+ align_boundary_except_vid = self.flame.mask.get_vid_by_region(
+ self.cfg.pipeline[stage].align_boundary_except
+ ) if stage is not None else None
+
+ render_out = self.render_rgba(
+ rast_dict, verts, faces, albedos, lights, bg_color,
+ align_texture_except_fid, align_boundary_except_vid,
+ enable_disturbance=stage!=None,
+ )
+
+ pred_rgb = render_out['rgba'][:, :3]
+ pred_alpha = render_out['rgba'][:, 3:]
+ pred_mask = render_out['rgba'][:, [3]].detach() > 0
+ pred_mask = pred_mask.expand(-1, 3, -1, -1)
+
+ results_dict = render_out
+
+ # ---- rgb loss ----
+ error_rgb = gt_rgb - pred_rgb
+ color_loss = error_rgb.abs().sum() / pred_mask.detach().sum()
+
+ results_dict.update(
+ {
+ "gt_rgb": gt_rgb,
+ "pred_rgb": pred_rgb,
+ "error_rgb": error_rgb,
+ "pred_alpha": pred_alpha,
+ }
+ )
+
+ # ---- silhouette loss ----
+ # error_alpha = gt_alpha - pred_alpha
+ # mask_loss = error_alpha.abs().sum()
+
+ # results_dict.update(
+ # {
+ # "gt_alpha": gt_alpha,
+ # "error_alpha": error_alpha,
+ # }
+ # )
+
+ # ---- background loss ----
+ # bg_mask = gt_alpha < 0.5
+ # error_alpha = gt_alpha - pred_alpha
+ # error_alpha = torch.where(bg_mask, error_alpha, torch.zeros_like(error_alpha))
+ # mask_loss = error_alpha.abs().sum() / bg_mask.sum()
+
+ # results_dict.update(
+ # {
+ # "gt_alpha": gt_alpha,
+ # "error_alpha": error_alpha,
+ # }
+ # )
+
+ # --------
+ # photo_loss = color_loss + mask_loss
+ photo_loss = color_loss
+ # photo_loss = mask_loss
+ return photo_loss, results_dict
+
+ def compute_regularization_energy(self, result_dict, verts, verts_cano, lmks, albedos, frame_idx, include_keyframes, stage):
+ """
+ Computes the energy term that penalizes strong deviations from the flame base model
+ """
+ log_dict = {}
+
+ std_tex = 1
+ std_expr = 1
+ std_shape = 1
+
+ indices = self.select_frame_indices(frame_idx, include_keyframes)
+
+ # pose smoothness term
+ if self.opt_dict['pose'] and 'tracking' in stage:
+ E_pose_smooth = self.compute_pose_smooth_energy(frame_idx, stage=='global_tracking')
+ log_dict["pose_smooth"] = E_pose_smooth
+
+ # joint regularization term
+ if self.opt_dict['joints']:
+ if 'tracking' in stage:
+ joint_smooth = self.compute_joint_smooth_energy(frame_idx, stage=='global_tracking')
+ log_dict["joint_smooth"] = joint_smooth
+
+ joint_prior = self.compute_joint_prior_energy(frame_idx)
+ log_dict["joint_prior"] = joint_prior
+
+ # expression regularization
+ if self.opt_dict['expr']:
+ expr = self.to_batch(self.expr, indices)
+ reg_expr = (expr / std_expr) ** 2
+ log_dict["reg_expr"] = self.cfg.w.reg_expr * reg_expr.mean()
+
+ # shape regularization
+ if self.opt_dict['shape']:
+ reg_shape = (self.shape / std_shape) ** 2
+ log_dict["reg_shape"] = self.cfg.w.reg_shape * reg_shape.mean()
+
+ # texture regularization
+ if self.opt_dict['texture']:
+ # texture space
+ if not self.cfg.model.tex_painted:
+ reg_tex_pca = (self.tex_pca / std_tex) ** 2
+ log_dict["reg_tex_pca"] = self.cfg.w.reg_tex_pca * reg_tex_pca.mean()
+
+ # texture map
+ if self.cfg.model.tex_extra:
+ if self.cfg.model.residual_tex:
+ if self.cfg.w.reg_tex_res is not None:
+ reg_tex_res = self.tex_extra ** 2
+ # reg_tex_res = self.tex_extra.abs() # L1 loss can create noise textures
+
+ # if len(self.cfg.model.occluded) > 0:
+ # mask = (~self.flame_uvmask.get_uvmask_by_region(self.cfg.model.occluded)).float()[None, ...]
+ # reg_tex_res *= mask
+ log_dict["reg_tex_res"] = self.cfg.w.reg_tex_res * reg_tex_res.mean()
+
+ if self.cfg.w.reg_tex_tv is not None:
+ tex = self.get_albedo()[0] # (3, H, W)
+ tv_y = (tex[..., :-1, :] - tex[..., 1:, :]) ** 2
+ tv_x = (tex[..., :, :-1] - tex[..., :, 1:]) ** 2
+ tv = tv_y.reshape(tv_y.shape[0], -1) + tv_x.reshape(tv_x.shape[0], -1)
+ w_reg_tex_tv = self.cfg.w.reg_tex_tv * self.cfg.data.scale_factor ** 2
+ if self.cfg.data.n_downsample_rgb is not None:
+ w_reg_tex_tv /= (self.cfg.data.n_downsample_rgb ** 2)
+ log_dict["reg_tex_tv"] = w_reg_tex_tv * tv.mean()
+
+ if self.cfg.w.reg_tex_res_clusters is not None:
+ mask_sclerae = self.flame_uvmask.get_uvmask_by_region(self.cfg.w.reg_tex_res_for)[None, :, :]
+ reg_tex_res_clusters = self.tex_extra ** 2 * mask_sclerae
+ log_dict["reg_tex_res_clusters"] = self.cfg.w.reg_tex_res_clusters * reg_tex_res_clusters.mean()
+
+ # lighting parameters regularization
+ if self.opt_dict['lights']:
+ if self.cfg.w.reg_light is not None and self.lights is not None:
+ reg_light = (self.lights - self.lights_uniform) ** 2
+ log_dict["reg_light"] = self.cfg.w.reg_light * reg_light.mean()
+
+ if self.cfg.w.reg_diffuse is not None and self.lights is not None:
+ diffuse = result_dict['diffuse_detach_normal']
+ reg_diffuse = F.relu(diffuse.max() - 1) + diffuse.var(dim=1).mean()
+ log_dict["reg_diffuse"] = self.cfg.w.reg_diffuse * reg_diffuse
+
+ # offset regularization
+ if self.opt_dict['static_offset'] or self.opt_dict['dynamic_offset']:
+ if self.static_offset is not None or self.dynamic_offset is not None:
+ offset = 0
+ if self.static_offset is not None:
+ offset += self.static_offset
+ if self.dynamic_offset is not None:
+ offset += self.to_batch(self.dynamic_offset, indices)
+
+ if self.cfg.w.reg_offset_lap is not None:
+ # laplacian loss
+ vert_wo_offset = (verts_cano - offset).detach()
+ reg_offset_lap = self.compute_laplacian_smoothing_loss(
+ vert_wo_offset, vert_wo_offset + offset
+ )
+ if len(self.cfg.w.reg_offset_lap_relax_for) > 0:
+ w = self.scale_vertex_weights_by_region(
+ weights=torch.ones_like(verts[:, :, :1]),
+ scale_factor=self.cfg.w.reg_offset_lap_relax_coef,
+ region=self.cfg.w.reg_offset_lap_relax_for,
+ )
+ reg_offset_lap *= w
+ log_dict["reg_offset_lap"] = self.cfg.w.reg_offset_lap * reg_offset_lap.mean()
+
+ if self.cfg.w.reg_offset is not None:
+ # norm loss
+ # reg_offset = offset.norm(dim=-1, keepdim=True)
+ reg_offset = offset.abs()
+ if len(self.cfg.w.reg_offset_relax_for) > 0:
+ w = self.scale_vertex_weights_by_region(
+ weights=torch.ones_like(verts[:, :, :1]),
+ scale_factor=self.cfg.w.reg_offset_relax_coef,
+ region=self.cfg.w.reg_offset_relax_for,
+ )
+ reg_offset *= w
+ log_dict["reg_offset"] = self.cfg.w.reg_offset * reg_offset.mean()
+
+ if self.cfg.w.reg_offset_rigid is not None:
+ reg_offset_rigid = 0
+ for region in self.cfg.w.reg_offset_rigid_for:
+ vids = self.flame.mask.get_vid_by_region([region])
+ reg_offset_rigid += offset[:, vids, :].var(dim=-2).mean()
+ log_dict["reg_offset_rigid"] = self.cfg.w.reg_offset_rigid * reg_offset_rigid
+
+ if self.cfg.w.reg_offset_dynamic is not None and self.dynamic_offset is not None and self.opt_dict['dynamic_offset']:
+ # The dynamic offset is regularized to be temporally smooth
+ if frame_idx == 0:
+ reg_offset_d = torch.zeros_like(self.dynamic_offset[0])
+ offset_d = self.dynamic_offset[0]
+ else:
+ reg_offset_d = torch.stack([self.dynamic_offset[0], self.dynamic_offset[frame_idx - 1]])
+ offset_d = self.dynamic_offset[frame_idx]
+
+ reg_offset_dynamic = ((offset_d - reg_offset_d) ** 2).mean()
+ log_dict["reg_offset_dynamic"] = self.cfg.w.reg_offset_dynamic * reg_offset_dynamic
+
+ return log_dict
+
+ def scale_vertex_weights_by_region(self, weights, scale_factor, region):
+ indices = self.flame.mask.get_vid_by_region(region)
+ weights[:, indices] *= scale_factor
+
+ for _ in range(self.cfg.w.blur_iter):
+ M = self.flame.laplacian_matrix_negate_diag[None, ...]
+ weights = M.bmm(weights) / 2
+ return weights
+
+ def compute_pose_smooth_energy(self, frame_idx, use_next_frame=False):
+ """
+ Regularizes the global pose of the flame head model to be temporally smooth
+ """
+ idx = frame_idx
+ idx_prev = np.clip(idx - 1, 0, self.n_timesteps - 1)
+ if use_next_frame:
+ idx_next = np.clip(idx + 1, 0, self.n_timesteps - 1)
+ ref_indices = [idx_prev, idx_next]
+ else:
+ ref_indices = [idx_prev]
+
+ E_trans = ((self.translation[[idx]] - self.translation[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_trans
+ E_rot = ((self.rotation[[idx]] - self.rotation[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_rot
+ return E_trans + E_rot
+
+ def compute_joint_smooth_energy(self, frame_idx, use_next_frame=False):
+ """
+ Regularizes the joints of the flame head model to be temporally smooth
+ """
+ idx = frame_idx
+ idx_prev = np.clip(idx - 1, 0, self.n_timesteps - 1)
+ if use_next_frame:
+ idx_next = np.clip(idx + 1, 0, self.n_timesteps - 1)
+ ref_indices = [idx_prev, idx_next]
+ else:
+ ref_indices = [idx_prev]
+
+ E_joint_smooth = 0
+ E_joint_smooth += ((self.neck_pose[[idx]] - self.neck_pose[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_neck
+ E_joint_smooth += ((self.jaw_pose[[idx]] - self.jaw_pose[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_jaw
+ E_joint_smooth += ((self.eyes_pose[[idx]] - self.eyes_pose[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_eyes
+ return E_joint_smooth
+
+ def compute_joint_prior_energy(self, frame_idx):
+ """
+ Regularizes the joints of the flame head model towards neutral joint locations
+ """
+ poses = [
+ ("neck", self.neck_pose[[frame_idx], :]),
+ ("jaw", self.jaw_pose[[frame_idx], :]),
+ ("eyes", self.eyes_pose[[frame_idx], :3]),
+ ("eyes", self.eyes_pose[[frame_idx], 3:]),
+ ]
+
+ # Joints should are regularized towards neural
+ E_joint_prior = 0
+ for name, pose in poses:
+ # L2 regularization for each joint
+ rotmats = batch_rodrigues(torch.cat([torch.zeros_like(pose), pose], dim=0))
+ diff = ((rotmats[[0]] - rotmats[1:]) ** 2).mean()
+
+ # Additional regularization for physical plausibility
+ if name == 'jaw':
+ # penalize negative rotation along x axis of jaw
+ diff += F.relu(-pose[:, 0]).mean() * 10
+
+ # penalize rotation along y and z axis of jaw
+ diff += (pose[:, 1:] ** 2).mean() * 3
+ elif name == 'eyes':
+ # penalize the difference between the two eyes
+ diff += ((self.eyes_pose[[frame_idx], :3] - self.eyes_pose[[frame_idx], 3:]) ** 2).mean()
+
+ E_joint_prior += diff * self.cfg.w[f"prior_{name}"]
+ return E_joint_prior
+
+ def compute_laplacian_smoothing_loss(self, verts, offset_verts):
+ L = self.flame.laplacian_matrix[None, ...].detach() # (1, V, V)
+ basis_lap = L.bmm(verts).detach() #.norm(dim=-1) * weights
+
+ offset_lap = L.bmm(offset_verts) #.norm(dim=-1) # * weights
+ diff = (offset_lap - basis_lap) ** 2
+ diff = diff.sum(dim=-1, keepdim=True)
+ return diff
+
+ def compute_energy(
+ self,
+ sample,
+ frame_idx,
+ include_keyframes=False,
+ step_i=None,
+ stage=None,
+ ):
+ """
+ Compute total energy for frame frame_idx
+ :param sample:
+ :param frame_idx:
+ :param include_keyframes: if key frames shall be included when predicting the per
+ frame energy
+ :return: loss, log dict, predicted vertices and landmarks
+ """
+ log_dict = {}
+
+ gt_rgb = sample["rgb"]
+ result_dict = {"gt_rgb": gt_rgb}
+
+ verts, verts_cano, lmks, albedos = self.forward_flame(frame_idx, include_keyframes)
+ faces = self.flame.faces
+
+ if isinstance(sample["num_cameras"], list):
+ num_cameras = sample["num_cameras"][0]
+ else:
+ num_cameras = sample["num_cameras"]
+ # albedos = self.repeat_n_times(albedos, num_cameras) # only needed for pytorch3d renderer
+
+ if self.cfg.w.landmark is not None:
+ lmks_n = self.repeat_n_times(lmks, num_cameras)
+ if not self.cfg.w.always_enable_jawline_landmarks and stage is not None:
+ disable_jawline_landmarks = self.cfg.pipeline[stage]['disable_jawline_landmarks']
+ else:
+ disable_jawline_landmarks = False
+ E_lmk, _result_dict = self.compute_lmk_energy(sample, lmks_n, disable_jawline_landmarks)
+ log_dict["lmk"] = self.cfg.w.landmark * E_lmk
+ result_dict.update(_result_dict)
+
+ if stage is None or isinstance(self.cfg.pipeline[stage], PhotometricStageConfig):
+ if self.cfg.w.photo is not None:
+ verts_n = self.repeat_n_times(verts, num_cameras)
+ rast_dict = self.rasterize_flame(
+ sample, verts_n, self.flame.faces, train_mode=True
+ )
+
+ photo_energy_func = self.compute_photometric_energy
+ E_photo, _result_dict = photo_energy_func(
+ sample,
+ verts,
+ faces,
+ albedos,
+ rast_dict,
+ step_i,
+ stage,
+ include_keyframes,
+ )
+ result_dict.update(_result_dict)
+ log_dict["photo"] = self.cfg.w.photo * E_photo
+
+ if stage is not None:
+ _log_dict = self.compute_regularization_energy(
+ result_dict, verts, verts_cano, lmks, albedos, frame_idx, include_keyframes, stage
+ )
+ log_dict.update(_log_dict)
+
+ E_total = torch.stack([v for k, v in log_dict.items()]).sum()
+ log_dict["total"] = E_total
+
+ return E_total, log_dict, verts, faces, lmks, albedos, result_dict
+
+ @staticmethod
+ def to_batch(x, indices):
+ return torch.stack([x[i] for i in indices])
+
+ @staticmethod
+ def repeat_n_times(x: torch.Tensor, n: int):
+ """Expand a tensor from shape [F, ...] to [F*n, ...]"""
+ return x.unsqueeze(1).repeat_interleave(n, dim=1).reshape(-1, *x.shape[1:])
+
+ @torch.no_grad()
+ def log_scalars(
+ self,
+ log_dict,
+ frame_idx,
+ session: Literal["train", "eval"] = "train",
+ stage=None,
+ frame_step=None,
+ # step_in_stage=None,
+ ):
+ """
+ Logs scalars in log_dict to tensorboard and self.logger
+ :param log_dict:
+ :param frame_idx:
+ :param step_i:
+ :return:
+ """
+
+ if not self.calibrated and stage is not None and 'cam' in self.cfg.pipeline[stage].optimizable_params:
+ log_dict["focal_length"] = self.focal_length.squeeze(0)
+
+ log_msg = ""
+
+ if session == "train":
+ global_step = self.global_step
+ else:
+ global_step = frame_idx
+
+ for k, v in log_dict.items():
+ if not k.startswith("decay"):
+ log_msg += "{}: {:.4f} ".format(k, v)
+ if self.tb_writer is not None:
+ self.tb_writer.add_scalar(f"{session}/{k}", v, global_step)
+
+ if session == "train":
+ assert stage is not None
+ if frame_step is not None:
+ msg_prefix = f"[{session}-{stage}] frame {frame_idx} step {frame_step}: "
+ else:
+ msg_prefix = f"[{session}-{stage}] frame {frame_idx} step {self.global_step}: "
+ elif session == "eval":
+ msg_prefix = f"[{session}] frame {frame_idx}: "
+ self.logger.info(msg_prefix + log_msg)
+
+ def save_obj_with_texture(self, vertices, faces, uv_coordinates, uv_indices, albedos, obj_path, mtl_path, texture_path):
+ # Save the texture image
+ torchvision.utils.save_image(albedos.squeeze(0), texture_path)
+
+ # Create the MTL file
+ with open(mtl_path, 'w') as f:
+ f.write(get_mtl_content(texture_path.name))
+
+ # Create the obj file
+ with open(obj_path, 'w') as f:
+ f.write(get_obj_content(vertices, faces, uv_coordinates, uv_indices, mtl_path.name))
+
+ def async_func(func):
+ """Decorator to run a function asynchronously"""
+ def wrapper(*args, **kwargs):
+ self = args[0]
+ if self.cfg.async_func:
+ thread = threading.Thread(target=func, args=args, kwargs=kwargs)
+ thread.start()
+ else:
+ func(*args, **kwargs)
+ return wrapper
+
+ @torch.no_grad()
+ @async_func
+ def log_media(
+ self,
+ verts: torch.tensor,
+ faces: torch.tensor,
+ lmks: torch.tensor,
+ albedos: torch.tensor,
+ output_dict: dict,
+ sample: dict,
+ frame_idx: int,
+ session: str,
+ stage: Optional[str]=None,
+ frame_step: int=None,
+ epoch=None,
+ ):
+ """
+ Logs current tracking visualization to tensorboard
+ :param verts:
+ :param lmks:
+ :param sample:
+ :param frame_idx:
+ :param frame_step:
+ :param show_lmks:
+ :param show_overlay:
+ :return:
+ """
+ tic = time.time()
+ prepare_output_path = partial(
+ self.prepare_output_path,
+ session=session,
+ frame_idx=frame_idx,
+ stage=stage,
+ step=frame_step,
+ epoch=epoch,
+ )
+
+ """images"""
+ if not self.cfg.w.always_enable_jawline_landmarks and stage is not None:
+ disable_jawline_landmarks = self.cfg.pipeline[stage]['disable_jawline_landmarks']
+ else:
+ disable_jawline_landmarks = False
+ img = self.visualize_tracking(verts, lmks, albedos, output_dict, sample, disable_jawline_landmarks=disable_jawline_landmarks)
+ img_path = prepare_output_path(folder_name="image_grid", file_type=self.cfg.log.image_format)
+ torchvision.utils.save_image(img, img_path)
+
+ """meshes"""
+ texture_path = prepare_output_path(folder_name="mesh", file_type=self.cfg.log.image_format)
+ mtl_path = prepare_output_path(folder_name="mesh", file_type="mtl")
+ obj_path = prepare_output_path(folder_name="mesh", file_type="obj")
+
+ vertices = verts.squeeze(0).detach().cpu().numpy()
+ faces = faces.detach().cpu().numpy()
+ uv_coordinates = self.flame.verts_uvs.cpu().numpy()
+ uv_indices = self.flame.textures_idx.cpu().numpy()
+ self.save_obj_with_texture(vertices, faces, uv_coordinates, uv_indices, albedos, obj_path, mtl_path, texture_path)
+ """"""
+
+ toc = time.time() - tic
+ if stage is not None:
+ msg_prefix = f"[{session}-{stage}] frame {frame_idx}"
+ else:
+ msg_prefix = f"[{session}] frame {frame_idx}"
+ if frame_step is not None:
+ msg_prefix += f" step {frame_step}"
+ self.logger.info(f"{msg_prefix}: Logging media took {toc:.2f}s")
+
+ @torch.no_grad()
+ def visualize_tracking(
+ self,
+ verts,
+ lmks,
+ albedos,
+ output_dict,
+ sample,
+ return_imgs_seperately=False,
+ disable_jawline_landmarks=False,
+ ):
+ """
+ Visualizes the tracking result
+ """
+ if len(self.cfg.log.view_indices) > 0:
+ view_indices = torch.tensor(self.cfg.log.view_indices)
+ else:
+ num_views = sample["rgb"].shape[0]
+ if num_views > 1:
+ step = (num_views - 1) // (self.cfg.log.max_num_views - 1)
+ view_indices = torch.arange(0, num_views, step=step)
+ else:
+ view_indices = torch.tensor([0])
+ num_views_log = len(view_indices)
+
+ imgs = []
+
+ # rgb
+ gt_rgb = output_dict["gt_rgb"][view_indices].cpu()
+ transfm = torchvision.transforms.Resize(gt_rgb.shape[-2:])
+ imgs += [img[None] for img in gt_rgb]
+
+ if "pred_rgb" in output_dict:
+ pred_rgb = transfm(output_dict["pred_rgb"][view_indices].cpu())
+ pred_rgb = torch.clip(pred_rgb, min=0, max=1)
+ imgs += [img[None] for img in pred_rgb]
+
+ if "error_rgb" in output_dict:
+ error_rgb = transfm(output_dict["error_rgb"][view_indices].cpu())
+ error_rgb = error_rgb.mean(dim=1) / 2 + 0.5
+ cmap = cm.get_cmap("seismic")
+ error_rgb = cmap(error_rgb.cpu())
+ error_rgb = torch.from_numpy(error_rgb[..., :3]).to(gt_rgb).permute(0, 3, 1, 2)
+ imgs += [img[None] for img in error_rgb]
+
+ # cluster id
+ if "cid" in output_dict:
+ cid = transfm(output_dict["cid"][view_indices].cpu())
+ cid = cid / cid.max()
+ cid = cid.expand(-1, 3, -1, -1).clone()
+
+ pred_alpha = transfm(output_dict["pred_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1)
+ bg = pred_alpha == 0
+ cid[bg] = 1
+ imgs += [img[None] for img in cid]
+
+ # albedo
+ if "albedo" in output_dict:
+ albedo = transfm(output_dict["albedo"][view_indices].cpu())
+ albedo = torch.clip(albedo, min=0, max=1)
+
+ pred_alpha = transfm(output_dict["pred_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1)
+ bg = pred_alpha == 0
+ albedo[bg] = 1
+ imgs += [img[None] for img in albedo]
+
+ # normal
+ if "normal" in output_dict:
+ normal = transfm(output_dict["normal"][view_indices].cpu())
+ normal = torch.clip(normal/2+0.5, min=0, max=1)
+ imgs += [img[None] for img in normal]
+
+ # diffuse
+ diffuse = None
+ if self.cfg.render.lighting_type != 'constant' and "diffuse" in output_dict:
+ diffuse = transfm(output_dict["diffuse"][view_indices].cpu())
+ diffuse = torch.clip(diffuse, min=0, max=1)
+ imgs += [img[None] for img in diffuse]
+
+ # aa
+ if "aa" in output_dict:
+ aa = transfm(output_dict["aa"][view_indices].cpu())
+ aa = torch.clip(aa, min=0, max=1)
+ imgs += [img[None] for img in aa]
+
+ # alpha
+ if "gt_alpha" in output_dict:
+ gt_alpha = transfm(output_dict["gt_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1)
+ imgs += [img[None] for img in gt_alpha]
+
+ if "pred_alpha" in output_dict:
+ pred_alpha = transfm(output_dict["pred_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1)
+ color_alpha = torch.tensor([0.2, 0.5, 1])[None, :, None, None]
+ fg_mask = (pred_alpha > 0).float()
+ if diffuse is not None:
+ fg_mask *= diffuse
+ w = 0.7
+ overlay_alpha = fg_mask * (w * color_alpha * pred_alpha + (1-w) * gt_rgb) \
+ + (1 - fg_mask) * gt_rgb
+ imgs += [img[None] for img in overlay_alpha]
+
+ if "error_alpha" in output_dict:
+ error_alpha = transfm(output_dict["error_alpha"][view_indices].cpu())
+ error_alpha = error_alpha.mean(dim=1) / 2 + 0.5
+ cmap = cm.get_cmap("seismic")
+ error_alpha = cmap(error_alpha.cpu())
+ error_alpha = (
+ torch.from_numpy(error_alpha[..., :3]).to(gt_rgb).permute(0, 3, 1, 2)
+ )
+ imgs += [img[None] for img in error_alpha]
+ else:
+ error_alpha = None
+
+ # landmark
+ vis_lmk = self.visualize_landmarks(gt_rgb, output_dict, view_indices, disable_jawline_landmarks)
+ if vis_lmk is not None:
+ imgs += [img[None] for img in vis_lmk]
+ # ----------------
+ num_types = len(imgs) // len(view_indices)
+
+ if return_imgs_seperately:
+ return imgs
+ else:
+ if self.cfg.log.stack_views_in_rows:
+ imgs = [imgs[j * num_views_log + i] for i in range(num_views_log) for j in range(num_types)]
+ imgs = torch.cat(imgs, dim=0).cpu()
+ return torchvision.utils.make_grid(imgs, nrow=num_types)
+ else:
+ imgs = torch.cat(imgs, dim=0).cpu()
+ return torchvision.utils.make_grid(imgs, nrow=num_views_log)
+
+ @torch.no_grad()
+ def visualize_landmarks(self, gt_rgb, output_dict, view_indices=torch.tensor([0]), disable_jawline_landmarks=False):
+ h, w = gt_rgb.shape[-2:]
+ unit = h / 750
+ wh = torch.tensor([[[w, h]]])
+ vis_lmk = None
+ if "gt_lmk2d" in output_dict:
+ gt_lmk2d = (output_dict['gt_lmk2d'][view_indices].cpu() * 0.5 + 0.5) * wh
+ if disable_jawline_landmarks:
+ gt_lmk2d = gt_lmk2d[:, 17:68]
+ else:
+ gt_lmk2d = gt_lmk2d[:, :68]
+ vis_lmk = gt_rgb.clone() if vis_lmk is None else vis_lmk
+ for i in range(len(view_indices)):
+ vis_lmk[i] = plot_landmarks_2d(
+ vis_lmk[i].clone(),
+ gt_lmk2d[[i]],
+ colors="green",
+ unit=unit,
+ input_float=True,
+ ).to(vis_lmk[i])
+ if "pred_lmk2d" in output_dict:
+ pred_lmk2d = (output_dict['pred_lmk2d'][view_indices].cpu() * 0.5 + 0.5) * wh
+ if disable_jawline_landmarks:
+ pred_lmk2d = pred_lmk2d[:, 17:68]
+ else:
+ pred_lmk2d = pred_lmk2d[:, :68]
+ vis_lmk = gt_rgb.clone() if vis_lmk is None else vis_lmk
+ for i in range(len(view_indices)):
+ vis_lmk[i] = plot_landmarks_2d(
+ vis_lmk[i].clone(),
+ pred_lmk2d[[i]],
+ colors="red",
+ unit=unit,
+ input_float=True,
+ ).to(vis_lmk[i])
+ return vis_lmk
+
+ @torch.no_grad()
+ def evaluate(self, make_visualization=True, epoch=0):
+ # always save parameters before evaluation
+ self.save_result(epoch=epoch)
+
+ self.logger.info("Started Evaluation")
+ # vid_frames = []
+ photo_loss = []
+ for frame_idx in range(self.n_timesteps):
+
+ sample = self.get_current_frame(frame_idx, include_keyframes=False)
+ self.clear_cache()
+ self.fill_cam_params_into_sample(sample)
+ (
+ E_total,
+ log_dict,
+ verts,
+ faces,
+ lmks,
+ albedos,
+ output_dict,
+ ) = self.compute_energy(sample, frame_idx)
+
+ self.log_scalars(log_dict, frame_idx, session="eval")
+ photo_loss.append(log_dict["photo"].item())
+
+ if make_visualization:
+ self.log_media(
+ verts,
+ faces,
+ lmks,
+ albedos,
+ output_dict,
+ sample,
+ frame_idx,
+ session="eval",
+ epoch=epoch,
+ )
+
+ self.tb_writer.add_scalar(f"eval_mean/photo", np.mean(photo_loss), epoch)
+
+ def prepare_output_path(self, session, frame_idx, folder_name, file_type, stage=None, step=None, epoch=None):
+ if epoch is not None:
+ output_folder = self.out_dir / f'{session}_{epoch}' / folder_name
+ else:
+ output_folder = self.out_dir / session / folder_name
+ os.makedirs(output_folder, exist_ok=True)
+
+ if stage is not None:
+ assert step is not None
+ fname = "frame_{:05d}_{:03d}_{}.{}".format(frame_idx, step, stage, file_type)
+ else:
+ fname = "frame_{:05d}.{}".format(frame_idx, file_type)
+ return output_folder / fname
+
+ def save_result(self, fname=None, epoch=None):
+ """
+ Saves tracked/optimized flame parameters.
+ :return:
+ """
+ # save parameters
+ keys = [
+ "rotation",
+ "translation",
+ "neck_pose",
+ "jaw_pose",
+ "eyes_pose",
+ "shape",
+ "expr",
+ "timestep_id",
+ "n_processed_frames",
+ ]
+ values = [
+ self.rotation,
+ self.translation,
+ self.neck_pose,
+ self.jaw_pose,
+ self.eyes_pose,
+ self.shape,
+ self.expr,
+ np.array(self.dataset.timestep_ids),
+ self.frame_idx,
+ ]
+ if not self.calibrated:
+ keys += ["focal_length"]
+ values += [self.focal_length]
+
+ if not self.cfg.model.tex_painted:
+ keys += ["tex"]
+ values += [self.tex_pca]
+
+ if self.cfg.model.tex_extra:
+ keys += ["tex_extra"]
+ values += [self.tex_extra]
+
+ if self.lights is not None:
+ keys += ["lights"]
+ values += [self.lights]
+
+ if self.cfg.model.use_static_offset:
+ keys += ["static_offset"]
+ values += [self.static_offset]
+
+ if self.cfg.model.use_dynamic_offset:
+ keys += ["dynamic_offset"]
+ values += [self.dynamic_offset]
+
+ export_dict = {}
+ for k, v in zip(keys, values):
+ if not isinstance(v, np.ndarray):
+ if isinstance(v, list):
+ v = torch.stack(v)
+ if isinstance(v, torch.Tensor):
+ v = v.detach().cpu().numpy()
+ export_dict[k] = v
+
+ export_dict["image_size"] = np.array(self.image_size)
+
+ fname = fname if fname is not None else "tracked_flame_params"
+ if epoch is not None:
+ fname = f"{fname}_{epoch}"
+ np.savez(self.out_dir / f'{fname}.npz', **export_dict)
+
+
+class GlobalTracker(FlameTracker):
+ def __init__(self, cfg: BaseTrackingConfig):
+ super().__init__(cfg)
+
+ self.calibrated = cfg.data.calibrated
+
+ # logging
+ out_dir = cfg.exp.output_folder / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+ out_dir.mkdir(parents=True,exist_ok=True)
+
+ self.frame_idx = self.cfg.begin_frame_idx
+ self.out_dir = out_dir
+ self.tb_writer = SummaryWriter(self.out_dir)
+
+ self.log_interval_scalar = self.cfg.log.interval_scalar
+ self.log_interval_media = self.cfg.log.interval_media
+
+ config_yaml_path = out_dir / 'config.yml'
+ config_yaml_path.write_text(yaml.dump(cfg), "utf8")
+ print(tyro.to_yaml(cfg))
+
+ self.logger = get_logger(__name__, root=True, log_dir=out_dir)
+
+ # data
+ self.dataset = import_module(cfg.data._target)(
+ cfg=cfg.data,
+ img_to_tensor=True,
+ batchify_all_views=True, # important to optimized all views together
+ )
+ # FlameTracker expects all views of a frame in a batch, which is undertaken by the
+ # dataset. Therefore batching is disabled for the dataloader
+
+ self.image_size = self.dataset[0]["rgb"].shape[-2:]
+ self.n_timesteps = len(self.dataset)
+
+ # parameters
+ self.init_params()
+
+ if self.cfg.model.flame_params_path is not None:
+ self.load_from_tracked_flame_params(self.cfg.model.flame_params_path)
+
+ def init_params(self):
+ train_tensors = []
+
+ # flame model params
+ self.shape = torch.zeros(self.cfg.model.n_shape).to(self.device)
+ self.expr = torch.zeros(self.n_timesteps, self.cfg.model.n_expr).to(self.device)
+
+ # joint axis angles
+ self.neck_pose = torch.zeros(self.n_timesteps, 3).to(self.device)
+ self.jaw_pose = torch.zeros(self.n_timesteps, 3).to(self.device)
+ self.eyes_pose = torch.zeros(self.n_timesteps, 6).to(self.device)
+
+ # rigid pose
+ self.translation = torch.zeros(self.n_timesteps, 3).to(self.device)
+ self.rotation = torch.zeros(self.n_timesteps, 3).to(self.device)
+
+ # texture and lighting params
+ self.tex_pca = torch.zeros(self.cfg.model.n_tex).to(self.device)
+ if self.cfg.model.tex_extra:
+ res = self.cfg.model.tex_resolution
+ self.tex_extra = torch.zeros(3, res, res).to(self.device)
+
+ if self.cfg.render.lighting_type == 'SH':
+ self.lights_uniform = torch.zeros(9, 3).to(self.device)
+ self.lights_uniform[0] = torch.tensor([np.sqrt(4 * np.pi)]).expand(3).float().to(self.device)
+ self.lights = self.lights_uniform.clone()
+ else:
+ self.lights = None
+
+ train_tensors += (
+ [self.shape, self.translation, self.rotation, self.neck_pose, self.jaw_pose, self.eyes_pose, self.expr,]
+ )
+
+ if not self.cfg.model.tex_painted:
+ train_tensors += [self.tex_pca]
+ if self.cfg.model.tex_extra:
+ train_tensors += [self.tex_extra]
+
+ if self.lights is not None:
+ train_tensors += [self.lights]
+
+ if self.cfg.model.use_static_offset:
+ self.static_offset = torch.zeros(1, self.flame.v_template.shape[0], 3).to(self.device)
+ train_tensors += [self.static_offset]
+ else:
+ self.static_offset = None
+
+ if self.cfg.model.use_dynamic_offset:
+ self.dynamic_offset = torch.zeros(self.n_timesteps, self.flame.v_template.shape[0], 3).to(self.device)
+ train_tensors += self.dynamic_offset
+ else:
+ self.dynamic_offset = None
+
+ # camera definition
+ if not self.calibrated:
+ # K contains focal length and principle point
+ self.focal_length = torch.tensor([1.5]).to(self.device)
+ self.RT = torch.eye(3, 4).to(self.device)
+ self.RT[2, 3] = -1 # (0, 0, -1) in w2c corresponds to (0, 0, 1) in c2w
+ train_tensors += [self.focal_length]
+
+ for t in train_tensors:
+ t.requires_grad = True
+
+ def optimize(self):
+ """
+ Optimizes flame parameters on all frames of the dataset with random rampling
+ :return:
+ """
+ self.global_step = 0
+
+ # first initialize frame either from calibration or previous frame
+ # with torch.no_grad():
+ # self.initialize_frame(frame_idx)
+
+ # sequential optimization of timesteps
+ self.logger.info(f"Start sequential tracking FLAME in {self.n_timesteps} frames")
+ dataloader = DataLoader(self.dataset, batch_size=None, shuffle=False, num_workers=0)
+ for sample in dataloader:
+ timestep = sample["timestep_index"][0].item()
+ if timestep == 0:
+ self.optimize_stage('lmk_init_rigid', sample)
+ self.optimize_stage('lmk_init_all', sample)
+ if self.cfg.exp.photometric:
+ self.optimize_stage('rgb_init_texture', sample)
+ self.optimize_stage('rgb_init_all', sample)
+ if self.cfg.model.use_static_offset:
+ self.optimize_stage('rgb_init_offset', sample)
+
+ if self.cfg.exp.photometric:
+ self.optimize_stage('rgb_sequential_tracking', sample)
+ else:
+ self.optimize_stage('lmk_sequential_tracking', sample)
+ self.initialize_next_timtestep(timestep)
+
+ self.evaluate(make_visualization=False, epoch=0)
+
+ self.logger.info(f"Start global optimization of all frames")
+ # global optimization with random sampling
+ dataloader = DataLoader(self.dataset, batch_size=None, shuffle=True, num_workers=0)
+ if self.cfg.exp.photometric:
+ self.optimize_stage(stage='rgb_global_tracking', dataloader=dataloader, lr_scale=0.1)
+ else:
+ self.optimize_stage(stage='lmk_global_tracking', dataloader=dataloader, lr_scale=0.1)
+
+ self.logger.info("All done.")
+
+ def optimize_stage(
+ self,
+ stage: Literal['lmk_init_rigid', 'lmk_init_all', 'rgb_init_texture', 'rgb_init_all', 'rgb_init_offset', 'rgb_sequential_tracking', 'rgb_global_tracking'],
+ sample = None,
+ dataloader = None,
+ lr_scale = 1.0,
+ ):
+ params = self.get_train_parameters(stage)
+ optimizer = self.configure_optimizer(params, lr_scale=lr_scale)
+
+ if sample is not None:
+ num_steps = self.cfg.pipeline[stage].num_steps
+ for step_i in range(num_steps):
+ self.optimize_iter(sample, optimizer, stage)
+ else:
+ assert dataloader is not None
+ num_epochs = self.cfg.pipeline[stage].num_epochs
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
+ for epoch_i in range(num_epochs):
+ self.logger.info(f"EPOCH {epoch_i+1} / {num_epochs}")
+ for step_i, sample in enumerate(dataloader):
+ self.optimize_iter(sample, optimizer, stage)
+ scheduler.step()
+
+ if (epoch_i + 1) % 10 == 0:
+ self.evaluate(make_visualization=True, epoch=epoch_i+1)
+
+ def optimize_iter(self, sample, optimizer, stage):
+ # compute loss and update parameters
+ self.clear_cache()
+
+ timestep_index = sample["timestep_index"][0]
+ self.fill_cam_params_into_sample(sample)
+ (
+ E_total,
+ log_dict,
+ verts,
+ faces,
+ lmks,
+ albedos,
+ output_dict,
+ ) = self.compute_energy(
+ sample, frame_idx=timestep_index, stage=stage,
+ )
+ optimizer.zero_grad()
+ E_total.backward()
+ optimizer.step()
+
+ # log energy terms and visualize
+ if (self.global_step+1) % self.log_interval_scalar == 0:
+ self.log_scalars(
+ log_dict,
+ timestep_index,
+ session="train",
+ stage=stage,
+ frame_step=self.global_step,
+ )
+
+ if (self.global_step+1) % self.log_interval_media == 0:
+ self.log_media(
+ verts,
+ faces,
+ lmks,
+ albedos,
+ output_dict,
+ sample,
+ timestep_index,
+ session="train",
+ stage=stage,
+ frame_step=self.global_step,
+ )
+ del verts, faces, lmks, albedos, output_dict
+ self.global_step += 1
+
+
+ def get_train_parameters(
+ self, stage: Literal['lmk_init_rigid', 'lmk_init_all', 'rgb_init_all', 'rgb_init_offset', 'rgb_sequential_tracking', 'rgb_global_tracking'],
+ ):
+ """
+ Collects the parameters to be optimized for the current frame
+ :return: dict of parameters
+ """
+ self.opt_dict = defaultdict(bool) # dict to keep track of which parameters are optimized
+ for p in self.cfg.pipeline[stage].optimizable_params:
+ self.opt_dict[p] = True
+
+ params = defaultdict(list) # dict to collect parameters to be optimized
+
+ # shared properties
+ if self.opt_dict["cam"] and not self.calibrated:
+ params["cam"] = [self.focal_length]
+
+ if self.opt_dict["shape"]:
+ params["shape"] = [self.shape]
+
+ if self.opt_dict["texture"]:
+ if not self.cfg.model.tex_painted:
+ params["tex"] = [self.tex_pca]
+ if self.cfg.model.tex_extra:
+ params["tex_extra"] = [self.tex_extra]
+
+ if self.opt_dict["static_offset"] and self.cfg.model.use_static_offset:
+ params["static_offset"] = [self.static_offset]
+
+ if self.opt_dict["lights"] and self.lights is not None:
+ params["lights"] = [self.lights]
+
+ # per-frame properties
+ if self.opt_dict["pose"]:
+ params["translation"].append(self.translation)
+ params["rotation"].append(self.rotation)
+
+ if self.opt_dict["joints"]:
+ params["eyes"].append(self.eyes_pose)
+ params["neck"].append(self.neck_pose)
+ params["jaw"].append(self.jaw_pose)
+
+ if self.opt_dict["expr"]:
+ params["expr"].append(self.expr)
+
+ if self.opt_dict["dynamic_offset"] and self.cfg.model.use_dynamic_offset:
+ params["dynamic_offset"].append(self.dynamic_offset)
+
+ return params
+
+ def initialize_next_timtestep(self, timestep):
+ if timestep < self.n_timesteps - 1:
+ self.translation[timestep + 1].data.copy_(self.translation[timestep])
+ self.rotation[timestep + 1].data.copy_(self.rotation[timestep])
+ self.neck_pose[timestep + 1].data.copy_(self.neck_pose[timestep])
+ self.jaw_pose[timestep + 1].data.copy_(self.jaw_pose[timestep])
+ self.eyes_pose[timestep + 1].data.copy_(self.eyes_pose[timestep])
+ self.expr[timestep + 1].data.copy_(self.expr[timestep])
+ if self.cfg.model.use_dynamic_offset:
+ self.dynamic_offset[timestep + 1].data.copy_(self.dynamic_offset[timestep])
diff --git a/vhap/track.py b/vhap/track.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f773089614e20b36471086fc62d740b4c04076f
--- /dev/null
+++ b/vhap/track.py
@@ -0,0 +1,21 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+import tyro
+
+from vhap.config.base import BaseTrackingConfig
+from vhap.model.tracker import GlobalTracker
+
+
+if __name__ == "__main__":
+ tyro.extras.set_accent_color("bright_yellow")
+ cfg = tyro.cli(BaseTrackingConfig)
+
+ tracker = GlobalTracker(cfg)
+ tracker.optimize()
diff --git a/vhap/track_nersemble.py b/vhap/track_nersemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..774736d0b5d871f41f07b0d141934451779544ca
--- /dev/null
+++ b/vhap/track_nersemble.py
@@ -0,0 +1,21 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+import tyro
+
+from vhap.config.nersemble import NersembleTrackingConfig
+from vhap.model.tracker import GlobalTracker
+
+
+if __name__ == "__main__":
+ tyro.extras.set_accent_color("bright_yellow")
+ cfg = tyro.cli(NersembleTrackingConfig)
+
+ tracker = GlobalTracker(cfg)
+ tracker.optimize()
diff --git a/vhap/util/camera.py b/vhap/util/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..610aca0d3546133ab1a8e9e8b14fda63ed4085ec
--- /dev/null
+++ b/vhap/util/camera.py
@@ -0,0 +1,223 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+from typing import Tuple, Literal
+import torch
+import torch.nn.functional as F
+import math
+import numpy as np
+from scipy.spatial.transform import Rotation
+
+
+def align_cameras_to_axes(
+ R: torch.Tensor,
+ T: torch.Tensor,
+ target_convention: Literal["opengl", "opencv"] = None,
+):
+ """align the averaged axes of cameras with the world axes.
+
+ Args:
+ R: rotation matrix (N, 3, 3)
+ T: translation vector (N, 3)
+ """
+ # The column vectors of R are the basis vectors of each camera.
+ # We construct new bases by taking the mean directions of axes, then use Gram-Schmidt
+ # process to make them orthonormal
+ bases_c2w = gram_schmidt_orthogonalization(R.mean(0))
+ if target_convention == "opengl":
+ bases_c2w[:, [1, 2]] *= -1 # flip y and z axes
+ elif target_convention == "opencv":
+ pass
+ bases_w2c = bases_c2w.t()
+
+ # convert the camera poses into the new coordinate system
+ R = bases_w2c[None, ...] @ R
+ T = bases_w2c[None, ...] @ T
+ return R, T
+
+
+def convert_camera_convention(camera_convention_conversion: str, R: torch.Tensor, K: torch.Tensor, H: int, W: int):
+ if camera_convention_conversion is not None:
+ if camera_convention_conversion == "opencv->opengl":
+ R[:, :3, [1, 2]] *= -1
+ # flip y of the principal point
+ K[..., 1, 2] = H - K[..., 1, 2]
+ elif camera_convention_conversion == "opencv->pytorch3d":
+ R[:, :3, [0, 1]] *= -1
+ # flip x and y of the principal point
+ K[..., 0, 2] = W - K[..., 0, 2]
+ K[..., 1, 2] = H - K[..., 1, 2]
+ elif camera_convention_conversion == "opengl->pytorch3d":
+ R[:, :3, [0, 2]] *= -1
+ # flip x of the principal point
+ K[..., 0, 2] = W - K[..., 0, 2]
+ else:
+ raise ValueError(
+ f"Unknown camera coordinate conversion: {camera_convention_conversion}."
+ )
+ return R, K
+
+
+def gram_schmidt_orthogonalization(M: torch.tensor):
+ """conducting Gram-Schmidt process to transform column vectors into orthogonal bases
+
+ Args:
+ M: An matrix (num_rows, num_cols)
+ Return:
+ M: An matrix with orthonormal column vectors (num_rows, num_cols)
+ """
+ num_rows, num_cols = M.shape
+ for c in range(1, num_cols):
+ M[:, [c - 1, c]] = F.normalize(M[:, [c - 1, c]], p=2, dim=0)
+ M[:, [c]] -= M[:, :c] @ (M[:, :c].T @ M[:, [c]])
+
+ M[:, -1] = F.normalize(M[:, -1], p=2, dim=0)
+ return M
+
+
+def projection_from_intrinsics(K: np.ndarray, image_size: Tuple[int], near: float=0.01, far:float=10, flip_y: bool=False, z_sign=-1):
+ """
+ Transform points from camera space (x: right, y: up, z: out) to clip space (x: right, y: down, z: in)
+ Args:
+ K: Intrinsic matrix, (N, 3, 3)
+ K = [[
+ [fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1],
+ ]
+ ]
+ image_size: (height, width)
+ Output:
+ proj = [[
+ [2*fx/w, 0.0, (w - 2*cx)/w, 0.0 ],
+ [0.0, 2*fy/h, (h - 2*cy)/h, 0.0 ],
+ [0.0, 0.0, z_sign*(far+near) / (far-near), -2*far*near / (far-near)],
+ [0.0, 0.0, z_sign, 0.0 ]
+ ]
+ ]
+ """
+
+ B = K.shape[0]
+ h, w = image_size
+
+ if K.shape[-2:] == (3, 3):
+ fx = K[..., 0, 0]
+ fy = K[..., 1, 1]
+ cx = K[..., 0, 2]
+ cy = K[..., 1, 2]
+ elif K.shape[-1] == 4:
+ # fx, fy, cx, cy = K[..., [0, 1, 2, 3]].split(1, dim=-1)
+ fx = K[..., [0]]
+ fy = K[..., [1]]
+ cx = K[..., [2]]
+ cy = K[..., [3]]
+ else:
+ raise ValueError(f"Expected K to be (N, 3, 3) or (N, 4) but got: {K.shape}")
+
+ proj = np.zeros([B, 4, 4])
+ proj[:, 0, 0] = fx * 2 / w
+ proj[:, 1, 1] = fy * 2 / h
+ proj[:, 0, 2] = (w - 2 * cx) / w
+ proj[:, 1, 2] = (h - 2 * cy) / h
+ proj[:, 2, 2] = z_sign * (far+near) / (far-near)
+ proj[:, 2, 3] = -2*far*near / (far-near)
+ proj[:, 3, 2] = z_sign
+
+ if flip_y:
+ proj[:, 1, 1] *= -1
+ return proj
+
+
+class OrbitCamera:
+ def __init__(self, W, H, r=2, fovy=60, znear=1e-8, zfar=10, convention: Literal["opengl", "opencv"]="opengl"):
+ self.image_width = W
+ self.image_height = H
+ self.radius_default = r
+ self.fovy_default = fovy
+ self.znear = znear
+ self.zfar = zfar
+ self.convention = convention
+
+ self.up = np.array([0, 1, 0], dtype=np.float32)
+ self.reset()
+
+ def reset(self):
+ """ The internal state of the camera is based on the OpenGL convention, but
+ properties are converted to the target convention when queried.
+ """
+ self.rot = Rotation.from_matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) # OpenGL convention
+ self.look_at = np.array([0, 0, 0], dtype=np.float32) # look at this point
+ self.radius = self.radius_default # camera distance from center
+ self.fovy = self.fovy_default
+ if self.convention == "opencv":
+ self.z_sign = 1
+ self.y_sign = 1
+ elif self.convention == "opengl":
+ self.z_sign = -1
+ self.y_sign = -1
+ else:
+ raise ValueError(f"Unknown convention: {self.convention}")
+
+ @property
+ def fovx(self):
+ return self.fovy / self.image_height * self.image_width
+
+ @property
+ def intrinsics(self):
+ focal = self.image_height / (2 * np.tan(np.radians(self.fovy) / 2))
+ return np.array([focal, focal, self.image_width // 2, self.image_height // 2])
+
+ @property
+ def projection_matrix(self):
+ return projection_from_intrinsics(self.intrinsics[None], (self.image_height, self.image_width), self.znear, self.zfar, z_sign=self.z_sign)[0]
+
+ @property
+ def world_view_transform(self):
+ return np.linalg.inv(self.pose) # world2cam
+
+ @property
+ def full_proj_transform(self):
+ return self.projection_matrix @ self.world_view_transform
+
+ @property
+ def pose(self):
+ # first move camera to radius
+ pose = np.eye(4, dtype=np.float32)
+ pose[2, 3] += self.radius
+
+ # rotate
+ rot = np.eye(4, dtype=np.float32)
+ rot[:3, :3] = self.rot.as_matrix()
+ pose = rot @ pose
+
+ # translate
+ pose[:3, 3] -= self.look_at
+
+ if self.convention == "opencv":
+ pose[:, [1, 2]] *= -1
+ elif self.convention == "opengl":
+ pass
+ else:
+ raise ValueError(f"Unknown convention: {self.convention}")
+ return pose
+
+ def orbit(self, dx, dy):
+ # rotate along camera up/side axis!
+ side = self.rot.as_matrix()[:3, 0]
+ rotvec_x = self.up * np.radians(-0.3 * dx)
+ rotvec_y = side * np.radians(-0.3 * dy)
+ self.rot = Rotation.from_rotvec(rotvec_x) * Rotation.from_rotvec(rotvec_y) * self.rot
+
+ def scale(self, delta):
+ self.radius *= 1.1 ** (-delta)
+
+ def pan(self, dx, dy, dz=0):
+ # pan in camera coordinate system (careful on the sensitivity!)
+ d = np.array([dx, -dy, dz]) # the y axis is flipped
+ self.look_at += 2 * self.rot.as_matrix()[:3, :3] @ d * self.radius / self.image_height * math.tan(np.radians(self.fovy) / 2)
diff --git a/vhap/util/landmark_detector_fa.py b/vhap/util/landmark_detector_fa.py
new file mode 100644
index 0000000000000000000000000000000000000000..d63011e4591a6a7cde2331308debe69618676d97
--- /dev/null
+++ b/vhap/util/landmark_detector_fa.py
@@ -0,0 +1,309 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+from vhap.util.log import get_logger
+
+from typing import Literal
+from tqdm import tqdm
+
+import face_alignment
+import numpy as np
+import matplotlib.path as mpltPath
+
+from fdlite import (
+ FaceDetection,
+ FaceLandmark,
+ face_detection_to_roi,
+ IrisLandmark,
+ iris_roi_from_face_landmarks,
+)
+
+logger = get_logger(__name__)
+
+
+class LandmarkDetectorFA:
+
+ IMAGE_FILE_NAME = "image_0000.png"
+ LMK_FILE_NAME = "keypoints_static_0000.json"
+
+ def __init__(
+ self,
+ face_detector:Literal["sfd", "blazeface"]="sfd",
+ ):
+ """
+ Creates dataset_path where all results are stored
+ :param video_path: path to video file
+ :param dataset_path: path to results directory
+ """
+
+ logger.info("Initialize FaceAlignment module...")
+ # 68 facial landmark detector
+ self.fa = face_alignment.FaceAlignment(
+ face_alignment.LandmarksType.TWO_HALF_D,
+ face_detector=face_detector,
+ flip_input=True,
+ device="cuda"
+ )
+
+ def detect_single_image(self, img):
+ bbox = self.fa.face_detector.detect_from_image(img)
+
+ if len(bbox) == 0:
+ lmks = np.zeros([68, 3]) - 1 # set to -1 when landmarks is inavailable
+
+ else:
+ if len(bbox) > 1:
+ # if multiple boxes detected, use the one with highest confidence
+ bbox = [bbox[np.argmax(np.array(bbox)[:, -1])]]
+
+ lmks = self.fa.get_landmarks_from_image(img, detected_faces=bbox)[0]
+ lmks = np.concatenate([lmks, np.ones_like(lmks[:, :1])], axis=1)
+
+ if (lmks[:, :2] == -1).sum() > 0:
+ lmks[:, 2:] = 0.0
+ else:
+ lmks[:, 2:] = 1.0
+
+ h, w = img.shape[:2]
+ lmks[:, 0] /= w
+ lmks[:, 1] /= h
+ bbox[0][[0, 2]] /= w
+ bbox[0][[1, 3]] /= h
+ return bbox, lmks
+
+ def detect_dataset(self, dataloader):
+ """
+ Annotates each frame with 68 facial landmarks
+ :return: dict mapping frame number to landmarks numpy array and the same thing for bboxes
+ """
+ landmarks = {}
+ bboxes = {}
+
+ logger.info("Begin annotating landmarks...")
+ for item in tqdm(dataloader):
+ timestep_id = item["timestep_id"][0]
+ camera_id = item["camera_id"][0]
+ scale_factor = item["scale_factor"][0]
+
+ logger.info(
+ f"Annotate facial landmarks for timestep: {timestep_id}, camera: {camera_id}"
+ )
+ img = item["rgb"][0].numpy()
+
+ bbox, lmks = self.detect_single_image(img)
+
+ if len(bbox) == 0:
+ logger.error(
+ f"No bbox found for frame: {timestep_id}, camera: {camera_id}. Setting landmarks to all -1."
+ )
+
+ if camera_id not in landmarks:
+ landmarks[camera_id] = {}
+ if camera_id not in bboxes:
+ bboxes[camera_id] = {}
+ landmarks[camera_id][timestep_id] = lmks
+ bboxes[camera_id][timestep_id] = bbox[0] if len(bbox) > 0 else np.zeros(5) - 1
+ return landmarks, bboxes
+
+ def annotate_iris_landmarks(self, dataloader):
+ """
+ Annotates each frame with 2 iris landmarks
+ :return: dict mapping frame number to landmarks numpy array
+ """
+
+ # iris detector
+ detect_faces = FaceDetection()
+ detect_face_landmarks = FaceLandmark()
+ detect_iris_landmarks = IrisLandmark()
+
+ landmarks = {}
+
+ for item in tqdm(dataloader):
+ timestep_id = item["timestep_id"][0]
+ camera_id = item["camera_id"][0]
+ scale_factor = item["scale_factor"][0]
+ if timestep_id not in landmarks:
+ landmarks[timestep_id] = {}
+ logger.info(
+ f"Annotate iris landmarks for timestep: {timestep_id}, camera: {camera_id}"
+ )
+
+ img = item["rgb"][0].numpy()
+
+ height, width = img.shape[:2]
+ img_size = (width, height)
+
+ face_detections = detect_faces(img)
+ if len(face_detections) != 1:
+ logger.error("Empty iris landmarks (type 1)")
+ landmarks[timestep_id][camera_id] = None
+ else:
+ for face_detection in face_detections:
+ try:
+ face_roi = face_detection_to_roi(face_detection, img_size)
+ except ValueError:
+ logger.error("Empty iris landmarks (type 2)")
+ landmarks[timestep_id][camera_id] = None
+ break
+
+ face_landmarks = detect_face_landmarks(img, face_roi)
+ if len(face_landmarks) == 0:
+ logger.error("Empty iris landmarks (type 3)")
+ landmarks[timestep_id][camera_id] = None
+ break
+
+ iris_rois = iris_roi_from_face_landmarks(face_landmarks, img_size)
+
+ if len(iris_rois) != 2:
+ logger.error("Empty iris landmarks (type 4)")
+ landmarks[timestep_id][camera_id] = None
+ break
+
+ lmks = []
+ for iris_roi in iris_rois[::-1]:
+ try:
+ iris_landmarks = detect_iris_landmarks(img, iris_roi).iris[
+ 0:1
+ ]
+ except np.linalg.LinAlgError:
+ logger.error("Failed to get iris landmarks")
+ landmarks[timestep_id][camera_id] = None
+ break
+
+ for landmark in iris_landmarks:
+ lmks.append([landmark.x * width, landmark.y * height, 1.0])
+
+ lmks = np.array(lmks, dtype=np.float32)
+
+ h, w = img.shape[:2]
+ lmks[:, 0] /= w
+ lmks[:, 1] /= h
+
+ landmarks[timestep_id][camera_id] = lmks
+
+ return landmarks
+
+ def iris_consistency(self, lm_iris, lm_eye):
+ """
+ Checks if landmarks for eye and iris are consistent
+ :param lm_iris:
+ :param lm_eye:
+ :return:
+ """
+ lm_iris = lm_iris[:, :2]
+ lm_eye = lm_eye[:, :2]
+
+ polygon_eye = mpltPath.Path(lm_eye)
+ valid = polygon_eye.contains_points(lm_iris)
+
+ return valid[0]
+
+ def annotate_landmarks(self, dataloader, add_iris=False):
+ """
+ Annotates each frame with landmarks for face and iris. Assumes frames have been extracted
+ :param add_iris:
+ :return:
+ """
+ lmks_face, bboxes_faces = self.detect_dataset(dataloader)
+
+ if add_iris:
+ lmks_iris = self.annotate_iris_landmarks(dataloader)
+
+ # check conistency of iris landmarks and facial keypoints
+ for camera_id, lmk_face_camera in lmks_face.items():
+ for timestep_id in lmk_face_camera.keys():
+
+ discard_iris_lmks = False
+ bboxes_face_i = bboxes_faces[camera_id][timestep_id]
+ if bboxes_face_i is not None:
+ lmks_face_i = lmks_face[camera_id][timestep_id]
+ lmks_iris_i = lmks_iris[camera_id][timestep_id]
+ if lmks_iris_i is not None:
+
+ # validate iris landmarks
+ left_face = lmks_face_i[36:42]
+ right_face = lmks_face_i[42:48]
+
+ right_iris = lmks_iris_i[:1]
+ left_iris = lmks_iris_i[1:]
+
+ if not (
+ self.iris_consistency(left_iris, left_face)
+ and self.iris_consistency(right_iris, right_face)
+ ):
+ logger.error(
+ f"Inconsistent iris landmarks for timestep: {timestep_id}, camera: {camera_id}"
+ )
+ discard_iris_lmks = True
+ else:
+ logger.error(
+ f"No iris landmarks detected for timestep: {timestep_id}, camera: {camera_id}"
+ )
+ discard_iris_lmks = True
+
+ else:
+ logger.error(
+ f"Discarding iris landmarks because no face landmark is available for timestep: {timestep_id}, camera: {camera_id}"
+ )
+ discard_iris_lmks = True
+
+ if discard_iris_lmks:
+ lmks_iris[timestep_id][camera_id] = (
+ np.zeros([2, 3]) - 1
+ ) # set to -1 for inconsistent iris landmarks
+
+ # construct final json
+ for camera_id, lmk_face_camera in lmks_face.items():
+ bounding_box = []
+ face_landmark_2d = []
+ iris_landmark_2d = []
+ for timestep_id in lmk_face_camera.keys():
+ bounding_box.append(bboxes_faces[camera_id][timestep_id][None])
+ face_landmark_2d.append(lmks_face[camera_id][timestep_id][None])
+
+ if add_iris:
+ iris_landmark_2d.append(lmks_iris[camera_id][timestep_id][None])
+
+ lmk_dict = {
+ "bounding_box": bounding_box,
+ "face_landmark_2d": face_landmark_2d,
+ }
+ if len(iris_landmark_2d) > 0:
+ lmk_dict["iris_landmark_2d"] = iris_landmark_2d
+
+ for k, v in lmk_dict.items():
+ if len(v) > 0:
+ lmk_dict[k] = np.concatenate(v, axis=0)
+ out_path = dataloader.dataset.get_property_path(
+ "landmark2d/face-alignment", camera_id=camera_id
+ )
+ logger.info(f"Saving landmarks to: {out_path}")
+ if not out_path.parent.exists():
+ out_path.parent.mkdir(parents=True)
+ np.savez(out_path, **lmk_dict)
+
+
+if __name__ == "__main__":
+ import tyro
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ from vhap.config.base import DataConfig, import_module
+
+ cfg = tyro.cli(DataConfig)
+ dataset = import_module(cfg._target)(
+ cfg=cfg,
+ img_to_tensor=False,
+ batchify_all_views=True,
+ )
+ dataset.items = dataset.items[:2]
+
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
+
+ detector = LandmarkDetectorFA()
+ detector.annotate_landmarks(dataloader)
diff --git a/vhap/util/landmark_detector_star.py b/vhap/util/landmark_detector_star.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb719f8333009c484f7df385e3f69af474a49b6
--- /dev/null
+++ b/vhap/util/landmark_detector_star.py
@@ -0,0 +1,351 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+from tqdm import tqdm
+import copy
+import argparse
+import torch
+import math
+import cv2
+import numpy as np
+import dlib
+
+from star.lib import utility
+from star.asset import predictor_path, model_path
+
+from vhap.util.log import get_logger
+logger = get_logger(__name__)
+
+
+class GetCropMatrix():
+ """
+ from_shape -> transform_matrix
+ """
+
+ def __init__(self, image_size, target_face_scale, align_corners=False):
+ self.image_size = image_size
+ self.target_face_scale = target_face_scale
+ self.align_corners = align_corners
+
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+ cosv = math.cos(angle)
+ sinv = math.sin(angle)
+
+ fx, fy = from_center
+ tx, ty = to_center
+
+ acos = scale * cosv
+ asin = scale * sinv
+
+ a0 = acos
+ a1 = -asin
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+ b0 = asin
+ b1 = acos
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+ rot_scale_m = np.array([
+ [a0, a1, a2],
+ [b0, b1, b2],
+ [0.0, 0.0, 1.0]
+ ], np.float32)
+ return rot_scale_m
+
+ def process(self, scale, center_w, center_h):
+ if self.align_corners:
+ to_w, to_h = self.image_size - 1, self.image_size - 1
+ else:
+ to_w, to_h = self.image_size, self.image_size
+
+ rot_mu = 0
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
+ shift_xy_mu = (0, 0)
+ matrix = self._compose_rotate_and_scale(
+ rot_mu, scale_mu, shift_xy_mu,
+ from_center=[center_w, center_h],
+ to_center=[to_w / 2.0, to_h / 2.0])
+ return matrix
+
+
+class TransformPerspective():
+ """
+ image, matrix3x3 -> transformed_image
+ """
+
+ def __init__(self, image_size):
+ self.image_size = image_size
+
+ def process(self, image, matrix):
+ return cv2.warpPerspective(
+ image, matrix, dsize=(self.image_size, self.image_size),
+ flags=cv2.INTER_LINEAR, borderValue=0)
+
+
+class TransformPoints2D():
+ """
+ points (nx2), matrix (3x3) -> points (nx2)
+ """
+
+ def process(self, srcPoints, matrix):
+ # nx3
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
+ desPoints = desPoints @ np.transpose(matrix) # nx3
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
+ return desPoints.astype(srcPoints.dtype)
+
+
+class Alignment:
+ def __init__(self, args, model_path, dl_framework, device_ids):
+ self.input_size = 256
+ self.target_face_scale = 1.0
+ self.dl_framework = dl_framework
+
+ # model
+ if self.dl_framework == "pytorch":
+ # conf
+ self.config = utility.get_config(args)
+ self.config.device_id = device_ids[0]
+ # set environment
+ utility.set_environment(self.config)
+ self.config.init_instance()
+ if self.config.logger is not None:
+ self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
+ self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
+
+ net = utility.get_net(self.config)
+ if device_ids == [-1]:
+ checkpoint = torch.load(model_path, map_location="cpu")
+ else:
+ checkpoint = torch.load(model_path)
+ net.load_state_dict(checkpoint["net"])
+ net = net.to(self.config.device_id)
+ net.eval()
+ self.alignment = net
+ else:
+ assert False
+
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
+ align_corners=True)
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
+ self.transformPoints2D = TransformPoints2D()
+
+ def norm_points(self, points, align_corners=False):
+ if align_corners:
+ # [0, SIZE-1] -> [-1, +1]
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
+ else:
+ # [-0.5, SIZE-0.5] -> [-1, +1]
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
+
+ def denorm_points(self, points, align_corners=False):
+ if align_corners:
+ # [-1, +1] -> [0, SIZE-1]
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
+ else:
+ # [-1, +1] -> [-0.5, SIZE-0.5]
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
+
+ def preprocess(self, image, scale, center_w, center_h):
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
+ input_tensor = self.transformPerspective.process(image, matrix)
+ input_tensor = input_tensor[np.newaxis, :]
+
+ input_tensor = torch.from_numpy(input_tensor)
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
+ input_tensor = input_tensor.to(self.config.device_id)
+ return input_tensor, matrix
+
+ def postprocess(self, srcPoints, coeff):
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
+ # matrix^(-1) * src = dst
+ # src = matrix * dst
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
+ for i in range(srcPoints.shape[0]):
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
+ return dstPoints
+
+ def analyze(self, image, scale, center_w, center_h):
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
+
+ if self.dl_framework == "pytorch":
+ with torch.no_grad():
+ output = self.alignment(input_tensor)
+ landmarks = output[-1][0]
+ else:
+ assert False
+
+ landmarks = self.denorm_points(landmarks)
+ landmarks = landmarks.data.cpu().numpy()[0]
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
+
+ return landmarks
+
+
+def draw_pts(img, pts, mode="pts", shift=4, color=(0, 255, 0), radius=1, thickness=1, save_path=None, dif=0,
+ scale=0.3, concat=False, ):
+ img_draw = copy.deepcopy(img)
+ for cnt, p in enumerate(pts):
+ if mode == "index":
+ cv2.putText(img_draw, str(cnt), (int(float(p[0] + dif)), int(float(p[1] + dif))), cv2.FONT_HERSHEY_SIMPLEX,
+ scale, color, thickness)
+ elif mode == 'pts':
+ if len(img_draw.shape) > 2:
+ # 此处来回切换是因为opencv的bug
+ img_draw = cv2.cvtColor(img_draw, cv2.COLOR_BGR2RGB)
+ img_draw = cv2.cvtColor(img_draw, cv2.COLOR_RGB2BGR)
+ cv2.circle(img_draw, (int(p[0] * (1 << shift)), int(p[1] * (1 << shift))), radius << shift, color, -1,
+ cv2.LINE_AA, shift=shift)
+ else:
+ raise NotImplementedError
+ if concat:
+ img_draw = np.concatenate((img, img_draw), axis=1)
+ if save_path is not None:
+ cv2.imwrite(save_path, img_draw)
+ return img_draw
+
+
+class LandmarkDetectorSTAR:
+ def __init__(
+ self,
+ ):
+ self.detector = dlib.get_frontal_face_detector()
+ self.shape_predictor = dlib.shape_predictor(predictor_path)
+
+ # facial landmark detector
+ args = argparse.Namespace()
+ args.config_name = 'alignment'
+ # could be downloaded here: https://drive.google.com/file/d/1aOx0wYEZUfBndYy_8IYszLPG_D2fhxrT/view
+ # model_path = '/path/to/WFLW_STARLoss_NME_4_02_FR_2_32_AUC_0_605.pkl'
+ device_ids = '0'
+ device_ids = list(map(int, device_ids.split(",")))
+ self.alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
+
+ def detect_single_image(self, img):
+ bbox = self.detector(img, 1)
+
+ if len(bbox) == 0:
+ bbox = np.zeros(5) - 1
+ lmks = np.zeros([68, 3]) - 1 # set to -1 when landmarks is inavailable
+ else:
+ face = self.shape_predictor(img, bbox[0])
+ shape = []
+ for i in range(68):
+ x = face.part(i).x
+ y = face.part(i).y
+ shape.append((x, y))
+ shape = np.array(shape)
+ x1, x2 = shape[:, 0].min(), shape[:, 0].max()
+ y1, y2 = shape[:, 1].min(), shape[:, 1].max()
+ scale = min(x2 - x1, y2 - y1) / 200 * 1.05
+ center_w = (x2 + x1) / 2
+ center_h = (y2 + y1) / 2
+
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
+ lmks = self.alignment.analyze(img, scale, center_w, center_h)
+
+ h, w = img.shape[:2]
+
+ lmks = np.concatenate([lmks, np.ones([lmks.shape[0], 1])], axis=1).astype(np.float32) # (x, y, 1)
+ lmks[:, 0] /= w
+ lmks[:, 1] /= h
+
+ bbox = np.array([bbox[0].left(), bbox[0].top(), bbox[0].right(), bbox[0].bottom(), 1.]).astype(np.float32) # (x1, y1, x2, y2, score)
+ bbox[[0, 2]] /= w
+ bbox[[1, 3]] /= h
+
+ return bbox, lmks
+
+ def detect_dataset(self, dataloader):
+ """
+ Annotates each frame with 68 facial landmarks
+ :return: dict mapping frame number to landmarks numpy array and the same thing for bboxes
+ """
+ logger.info("Initialize Landmark Detector (STAR)...")
+ # 68 facial landmark detector
+
+ landmarks = {}
+ bboxes = {}
+
+ logger.info("Begin annotating landmarks...")
+ for item in tqdm(dataloader):
+ timestep_id = item["timestep_id"][0]
+ camera_id = item["camera_id"][0]
+
+ logger.info(
+ f"Annotate facial landmarks for timestep: {timestep_id}, camera: {camera_id}"
+ )
+ img = item["rgb"][0].numpy()
+
+ bbox, lmks = self.detect_single_image(img)
+ if len(bbox) == 0:
+ logger.error(
+ f"No bbox found for frame: {timestep_id}, camera: {camera_id}. Setting landmarks to all -1."
+ )
+
+ if camera_id not in landmarks:
+ landmarks[camera_id] = {}
+ if camera_id not in bboxes:
+ bboxes[camera_id] = {}
+ landmarks[camera_id][timestep_id] = lmks
+ bboxes[camera_id][timestep_id] = bbox
+ return landmarks, bboxes
+
+ def annotate_landmarks(self, dataloader):
+ """
+ Annotates each frame with landmarks for face and iris. Assumes frames have been extracted
+ :return:
+ """
+ lmks_face, bboxes_faces = self.detect_dataset(dataloader)
+
+ # construct final json
+ for camera_id, lmk_face_camera in lmks_face.items():
+ bounding_box = []
+ face_landmark_2d = []
+ for timestep_id in lmk_face_camera.keys():
+ bounding_box.append(bboxes_faces[camera_id][timestep_id][None])
+ face_landmark_2d.append(lmks_face[camera_id][timestep_id][None])
+
+ lmk_dict = {
+ "bounding_box": bounding_box,
+ "face_landmark_2d": face_landmark_2d,
+ }
+
+ for k, v in lmk_dict.items():
+ if len(v) > 0:
+ lmk_dict[k] = np.concatenate(v, axis=0)
+ out_path = dataloader.dataset.get_property_path(
+ "landmark2d/STAR", camera_id=camera_id
+ )
+ logger.info(f"Saving landmarks to: {out_path}")
+ if not out_path.parent.exists():
+ out_path.parent.mkdir(parents=True)
+ np.savez(out_path, **lmk_dict)
+
+
+if __name__ == "__main__":
+ import tyro
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ from vhap.config.base import DataConfig, import_module
+
+ cfg = tyro.cli(DataConfig)
+ dataset = import_module(cfg._target)(
+ cfg=cfg,
+ img_to_tensor=False,
+ batchify_all_views=True,
+ )
+ dataset.items = dataset.items[:2]
+
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
+
+ detector = LandmarkDetectorSTAR()
+ detector.annotate_landmarks(dataloader)
diff --git a/vhap/util/log.py b/vhap/util/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..078dab10edef4bc4d40672c17d3f6517cae6c9c9
--- /dev/null
+++ b/vhap/util/log.py
@@ -0,0 +1,88 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+import logging
+import sys
+from datetime import datetime
+import atexit
+from pathlib import Path
+
+
+def _colored(msg, color):
+ colors = {'red': '\033[91m', 'green': '\033[92m', 'yellow': '\033[93m', 'normal': '\033[0m'}
+ return colors[color] + msg + colors["normal"]
+
+
+class ColorFormatter(logging.Formatter):
+ """
+ Class to make command line log entries more appealing
+ Inspired by https://github.com/facebookresearch/detectron2
+ """
+
+ def formatMessage(self, record):
+ """
+ Print warnings yellow and errors red
+ :param record:
+ :return:
+ """
+ log = super().formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = _colored("WARNING", "yellow")
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = _colored("ERROR", "red")
+ else:
+ return log
+ return prefix + " " + log
+
+
+def get_logger(name, level=logging.DEBUG, root=False, log_dir=None):
+ """
+ Replaces the standard library logging.getLogger call in order to make some configuration
+ for all loggers.
+ :param name: pass the __name__ variable
+ :param level: the desired log level
+ :param root: call only once in the program
+ :param log_dir: if root is set to True, this defines the directory where a log file is going
+ to be created that contains all logging output
+ :return: the logger object
+ """
+ logger = logging.getLogger(name)
+ logger.setLevel(level)
+
+ if root:
+ # create handler for console
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_handler.setLevel(level)
+ formatter = ColorFormatter(_colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S")
+ console_handler.setFormatter(formatter)
+ logger.addHandler(console_handler)
+ logger.propagate = False # otherwise root logger prints things again
+
+ if log_dir is not None:
+ # add handler to log to a file
+ log_dir = Path(log_dir)
+ if not log_dir.exists():
+ logger.info(f"Logging directory {log_dir} does not exist and will be created")
+ log_dir.mkdir(parents=True)
+ timestamp = datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
+ log_file = log_dir / f"{timestamp}.log"
+
+ # open stream and make sure it will be closed
+ stream = log_file.open(mode="w")
+ atexit.register(stream.close)
+
+ formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
+ datefmt="%m/%d %H:%M:%S")
+ file_handler = logging.StreamHandler(stream)
+ file_handler.setLevel(level)
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+
+ return logger
diff --git a/vhap/util/mesh.py b/vhap/util/mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..76f6fa80798e91497eebecdd24015929ae2b64ec
--- /dev/null
+++ b/vhap/util/mesh.py
@@ -0,0 +1,73 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+import torch
+
+
+def get_mtl_content(tex_fname):
+ return f'newmtl Material\nmap_Kd {tex_fname}\n'
+
+def get_obj_content(vertices, faces, uv_coordinates=None, uv_indices=None, mtl_fname=None):
+ obj = ('# Generated with multi-view-head-tracker\n')
+
+ if mtl_fname is not None:
+ obj += f'mtllib {mtl_fname}\n'
+ obj += 'usemtl Material\n'
+
+ # Write the vertices
+ for vertex in vertices:
+ obj += f"v {vertex[0]} {vertex[1]} {vertex[2]}\n"
+
+ # Write the UV coordinates
+ if uv_coordinates is not None:
+ for uv in uv_coordinates:
+ obj += f"vt {uv[0]} {uv[1]}\n"
+
+ # Write the faces with UV indices
+ if uv_indices is not None:
+ for face, uv_indices in zip(faces, uv_indices):
+ obj += f"f {face[0]+1}/{uv_indices[0]+1} {face[1]+1}/{uv_indices[1]+1} {face[2]+1}/{uv_indices[2]+1}\n"
+ else:
+ for face in faces:
+ obj += f"f {face[0]+1} {face[1]+1} {face[2]+1}\n"
+ return obj
+
+def normalize_image_points(u, v, resolution):
+ """
+ normalizes u, v coordinates from [0 ,image_size] to [-1, 1]
+ :param u:
+ :param v:
+ :param resolution:
+ :return:
+ """
+ u = 2 * (u - resolution[1] / 2.0) / resolution[1]
+ v = 2 * (v - resolution[0] / 2.0) / resolution[0]
+ return u, v
+
+
+def face_vertices(vertices, faces):
+ """
+ :param vertices: [batch size, number of vertices, 3]
+ :param faces: [batch size, number of faces, 3]
+ :return: [batch size, number of faces, 3, 3]
+ """
+ assert vertices.ndimension() == 3
+ assert faces.ndimension() == 3
+ assert vertices.shape[0] == faces.shape[0]
+ assert vertices.shape[2] == 3
+ assert faces.shape[2] == 3
+
+ bs, nv = vertices.shape[:2]
+ bs, nf = faces.shape[:2]
+ device = vertices.device
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+ vertices = vertices.reshape((bs * nv, 3))
+ # pytorch only supports long and byte tensors for indexing
+ return vertices[faces.long()]
+
diff --git a/vhap/util/render_nvdiffrast.py b/vhap/util/render_nvdiffrast.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cb8c120a797461eb191f27b12309139278ee0e2
--- /dev/null
+++ b/vhap/util/render_nvdiffrast.py
@@ -0,0 +1,599 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+from typing import Tuple, Literal, Optional
+# from pytorch3d.structures.meshes import Meshes
+import nvdiffrast.torch as dr
+import torch.nn.functional as F
+import torch
+import numpy as np
+from vhap.util import vector_ops as V
+
+
+def get_SH_shading(normals, sh_coefficients, sh_const):
+ """
+ :param normals: shape N, H, W, K, 3
+ :param sh_coefficients: shape N, 9, 3
+ :return:
+ """
+
+ N = normals
+
+ # compute sh basis function values of shape [N, H, W, K, 9]
+ sh = torch.stack(
+ [
+ N[..., 0] * 0.0 + 1.0,
+ N[..., 0],
+ N[..., 1],
+ N[..., 2],
+ N[..., 0] * N[..., 1],
+ N[..., 0] * N[..., 2],
+ N[..., 1] * N[..., 2],
+ N[..., 0] ** 2 - N[..., 1] ** 2,
+ 3 * (N[..., 2] ** 2) - 1,
+ ],
+ dim=-1,
+ )
+ sh = sh * sh_const[None, None, None, :].to(sh.device)
+
+ # shape [N, H, W, K, 9, 1]
+ sh = sh[..., None]
+
+ # shape [N, H, W, K, 9, 3]
+ sh_coefficients = sh_coefficients[:, None, None, :, :]
+
+ # shape after linear combination [N, H, W, K, 3]
+ shading = torch.sum(sh_coefficients * sh, dim=3)
+ return shading
+
+
+class NVDiffRenderer(torch.nn.Module):
+ def __init__(
+ self,
+ use_opengl: bool = False,
+ lighting_type: Literal['constant', 'front', 'front-range', 'SH'] = 'front',
+ lighting_space: Literal['camera', 'world'] = 'world',
+ disturb_rate_fg: Optional[float] = 0.5,
+ disturb_rate_bg: Optional[float] = 0.5,
+ fid2cid: Optional[torch.Tensor] = None,
+ ):
+ super().__init__()
+ self.backend = 'nvdiffrast'
+ self.lighting_type = lighting_type
+ self.lighting_space = lighting_space
+ self.disturb_rate_fg = disturb_rate_fg
+ self.disturb_rate_bg = disturb_rate_bg
+ self.glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()
+ self.fragment_cache = None
+
+ if fid2cid is not None:
+ fid2cid = F.pad(fid2cid, [1, 0], value=0) # for nvdiffrast, fid==0 means background pixels
+ self.register_buffer("fid2cid", fid2cid, persistent=False)
+
+ # constant factor of first three bands of spherical harmonics
+ pi = np.pi
+ sh_const = torch.tensor(
+ [
+ 1 / np.sqrt(4 * pi),
+ ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
+ ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
+ ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
+ (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
+ (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
+ (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
+ (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))),
+ (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))),
+ ],
+ dtype=torch.float32,
+ )
+ self.register_buffer("sh_const", sh_const, persistent=False)
+
+ def clear_cache(self):
+ self.fragment_cache = None
+
+ def mvp_from_camera_param(self, RT, K, image_size):
+ # projection matrix
+ proj = self.projection_from_intrinsics(K, image_size)
+
+ # Modelview and modelview + projection matrices.
+ if RT.shape[-2] == 3:
+ mv = torch.nn.functional.pad(RT, [0, 0, 0, 1])
+ mv[..., 3, 3] = 1
+ elif RT.shape[-2] == 4:
+ mv = RT
+ mvp = torch.bmm(proj, mv)
+ return mvp
+
+ def projection_from_intrinsics(self, K: torch.Tensor, image_size: Tuple[int], near: float=0.1, far:float=10):
+ """
+ Transform points from camera space (x: right, y: up, z: out) to clip space (x: right, y: down, z: in)
+ Args:
+ K: Intrinsic matrix, (N, 3, 3)
+ K = [[
+ [fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1],
+ ]
+ ]
+ image_size: (height, width)
+ Output:
+ proj = [[
+ [2*fx/w, 0.0, (w - 2*cx)/w, 0.0 ],
+ [0.0, 2*fy/h, (h - 2*cy)/h, 0.0 ],
+ [0.0, 0.0, -(far+near) / (far-near), -2*far*near / (far-near)],
+ [0.0, 0.0, -1.0, 0.0 ]
+ ]
+ ]
+ """
+
+ B = K.shape[0]
+ h, w = image_size
+
+ if K.shape[-2:] == (3, 3):
+ fx = K[..., 0, 0]
+ fy = K[..., 1, 1]
+ cx = K[..., 0, 2]
+ cy = K[..., 1, 2]
+ elif K.shape[-1] == 4:
+ fx, fy, cx, cy = K[..., [0, 1, 2, 3]].split(1, dim=-1)
+ else:
+ raise ValueError(f"Expected K to be (N, 3, 3) or (N, 4) but got: {K.shape}")
+
+ proj = torch.zeros([B, 4, 4], device=K.device)
+ proj[:, 0, 0] = fx * 2 / w
+ proj[:, 1, 1] = fy * 2 / h
+ proj[:, 0, 2] = (w - 2 * cx) / w
+ proj[:, 1, 2] = (h - 2 * cy) / h
+ proj[:, 2, 2] = -(far+near) / (far-near)
+ proj[:, 2, 3] = -2*far*near / (far-near)
+ proj[:, 3, 2] = -1
+ return proj
+
+ def world_to_camera(self, vtx, RT):
+ """Transform vertex positions from the world space to the camera space"""
+ RT = torch.from_numpy(RT).cuda() if isinstance(RT, np.ndarray) else RT
+ if RT.shape[-2] == 3:
+ mv = torch.nn.functional.pad(RT, [0, 0, 0, 1])
+ mv[..., 3, 3] = 1
+ elif RT.shape[-2] == 4:
+ mv = RT
+
+ # (x,y,z) -> (x',y',z',w)
+ assert vtx.shape[-1] in [3, 4]
+ if vtx.shape[-1] == 3:
+ posw = torch.cat([vtx, torch.ones([*vtx.shape[:2], 1]).cuda()], axis=-1)
+ elif vtx.shape[-1] == 4:
+ posw = vtx
+ else:
+ raise ValueError(f"Expected 3D or 4D points but got: {vtx.shape[-1]}")
+ return torch.bmm(posw, RT.transpose(-1, -2))
+
+ def camera_to_clip(self, vtx, K, image_size):
+ """Transform vertex positions from the camera space to the clip space"""
+ K = torch.from_numpy(K).cuda() if isinstance(K, np.ndarray) else K
+ proj = self.projection_from_intrinsics(K, image_size)
+
+ # (x,y,z) -> (x',y',z',w)
+ assert vtx.shape[-1] in [3, 4]
+ if vtx.shape[-1] == 3:
+ posw = torch.cat([vtx, torch.ones([*vtx.shape[:2], 1]).cuda()], axis=-1)
+ elif vtx.shape[-1] == 4:
+ posw = vtx
+ else:
+ raise ValueError(f"Expected 3D or 4D points but got: {vtx.shape[-1]}")
+ return torch.bmm(posw, proj.transpose(-1, -2))
+
+ def world_to_clip(self, vtx, RT, K, image_size):
+ """Transform vertex positions from the world space to the clip space"""
+ mvp = self.mvp_from_camera_param(RT, K, image_size)
+
+ mvp = torch.from_numpy(mvp).cuda() if isinstance(mvp, np.ndarray) else mvp
+ # (x,y,z) -> (x',y',z',w)
+ posw = torch.cat([vtx, torch.ones([*vtx.shape[:2], 1]).cuda()], axis=-1)
+ return torch.bmm(posw, mvp.transpose(-1, -2))
+
+ def world_to_ndc(self, vtx, RT, K, image_size, flip_y=False):
+ """Transform vertex positions from the world space to the NDC space"""
+ verts_clip = self.world_to_clip(vtx, RT, K, image_size)
+ verts_ndc = verts_clip[:, :, :3] / verts_clip[:, :, 3:]
+ if flip_y:
+ verts_ndc[:, :, 1] *= -1
+ return verts_ndc
+
+ def rasterize(self, verts, faces, RT, K, image_size, use_cache=False, require_grad=False):
+ """
+ Rasterizes meshes using a standard rasterization approach
+ :param meshes:
+ :param cameras:
+ :param image_size:
+ :return: fragments:
+ screen_coords: N x H x W x 2 with x, y values following pytorch3ds NDC-coord system convention
+ top left = +1, +1 ; bottom_right = -1, -1
+ """
+ # v_normals = self.compute_v_normals(verts, faces)
+ # vertices and faces
+ verts_camera = self.world_to_camera(verts, RT)
+ verts_clip = self.camera_to_clip(verts_camera, K, image_size)
+ tri = faces.int()
+ rast_out, rast_out_db = self.rasterize_fragments(verts_clip, tri, image_size, use_cache, require_grad)
+ rast_dict = {
+ "rast_out": rast_out,
+ "rast_out_db": rast_out_db,
+ "verts": verts,
+ "verts_camera": verts_camera[..., :3],
+ "verts_clip": verts_clip,
+ }
+
+ # if not require_grad:
+ # verts_ndc = verts_clip[:, :, :3] / verts_clip[:, :, 3:]
+ # screen_coords = self.compute_screen_coords(rast_out, verts_ndc, faces, image_size)
+ # rast_dict["screen_coords"] = screen_coords
+
+ return rast_dict
+
+ def rasterize_fragments(self, verts_clip, tri, image_size, use_cache, require_grad=False):
+ """
+ Either rasterizes meshes or returns cached result
+ """
+
+ if not use_cache or self.fragment_cache is None:
+ if require_grad:
+ rast_out, rast_out_db = dr.rasterize(self.glctx, verts_clip, tri, image_size)
+ else:
+ with torch.no_grad():
+ rast_out, rast_out_db = dr.rasterize(self.glctx, verts_clip, tri, image_size)
+ self.fragment_cache = (rast_out, rast_out_db)
+
+ return self.fragment_cache
+
+ def compute_screen_coords(self, rast_out: torch.Tensor, verts:torch.Tensor, faces:torch.Tensor, image_size: Tuple[int]):
+ """ Compute screen coords for visible pixels
+ Args:
+ verts: (N, V, 3), the verts should lie in the ndc space
+ faces: (F, 3)
+ """
+ N = verts.shape[0]
+ F = faces.shape[0]
+ meshes = Meshes(verts, faces[None, ...].expand(N, -1, -1))
+ verts_packed = meshes.verts_packed()
+ faces_packed = meshes.faces_packed()
+ face_verts = verts_packed[faces_packed]
+
+ # NOTE: nvdiffrast shifts face index by +1, and use 0 to flag empty pixel
+ pix2face = rast_out[..., -1:].long() - 1 # (N, H, W, 1)
+ is_visible = pix2face > -1 # (N, H, W, 1)
+ # NOTE: is_visible is computed before packing pix2face to ensure correctness
+ pix2face_packed = pix2face + torch.arange(0, N)[:, None, None, None].to(pix2face) * F
+
+ bary_coords = rast_out[..., :2] # (N, H, W, 2)
+ bary_coords = torch.cat([bary_coords, 1 - bary_coords.sum(dim=-1, keepdim=True)], dim =-1) # (N, H, W, 3)
+
+ visible_faces = pix2face_packed[is_visible] # (sum(is_visible), 3, 3)
+ visible_face_verts = face_verts[visible_faces]
+ visible_bary_coords = bary_coords[is_visible[..., 0]] # (sum(is_visible), 3, 1)
+ # visible_bary_coords = torch.cat([visible_bary_coords, 1 - visible_bary_coords.sum(dim=-1, keepdim=True)], dim =-1)
+
+ visible_surface_point = visible_face_verts * visible_bary_coords[..., None]
+ visible_surface_point = visible_surface_point.sum(dim=1)
+
+ screen_coords = torch.zeros(*pix2face_packed.shape[:3], 2, device=meshes.device)
+ screen_coords[is_visible[..., 0]] = visible_surface_point[:, :2] # now have gradient
+
+ return screen_coords
+
+ def compute_v_normals(self, verts, faces):
+ i0 = faces[..., 0].long()
+ i1 = faces[..., 1].long()
+ i2 = faces[..., 2].long()
+
+ v0 = verts[..., i0, :]
+ v1 = verts[..., i1, :]
+ v2 = verts[..., i2, :]
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+ v_normals = torch.zeros_like(verts)
+ N = verts.shape[0]
+ v_normals.scatter_add_(1, i0[..., None].repeat(N, 1, 3), face_normals)
+ v_normals.scatter_add_(1, i1[..., None].repeat(N, 1, 3), face_normals)
+ v_normals.scatter_add_(1, i2[..., None].repeat(N, 1, 3), face_normals)
+
+ v_normals = torch.where(V.dot(v_normals, v_normals) > 1e-20, v_normals, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))
+ v_normals = V.safe_normalize(v_normals)
+ if torch.is_anomaly_enabled():
+ assert torch.all(torch.isfinite(v_normals))
+ return v_normals
+
+ def compute_face_normals(self, verts, faces):
+ i0 = faces[..., 0].long()
+ i1 = faces[..., 1].long()
+ i2 = faces[..., 2].long()
+
+ v0 = verts[..., i0, :]
+ v1 = verts[..., i1, :]
+ v2 = verts[..., i2, :]
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+ face_normals = V.safe_normalize(face_normals)
+ if torch.is_anomaly_enabled():
+ assert torch.all(torch.isfinite(face_normals))
+ return face_normals
+
+ def shade(self, normal, lighting_coeff=None):
+ if self.lighting_type == 'constant':
+ diffuse = torch.ones_like(normal[..., :3])
+ elif self.lighting_type == 'front':
+ # diffuse = torch.clamp(V.dot(normal, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda')), 0.0, 1.0)
+ diffuse = V.dot(normal, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))
+ mask_backface = diffuse < 0
+ diffuse[mask_backface] = diffuse[mask_backface].abs()*0.3
+ elif self.lighting_type == 'front-range':
+ bias = 0.75
+ diffuse = torch.clamp(V.dot(normal, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda')) + bias, 0.0, 1.0)
+ elif self.lighting_type == 'SH':
+ diffuse = get_SH_shading(normal, lighting_coeff, self.sh_const)
+ else:
+ raise NotImplementedError(f"Unknown lighting type: {self.lighting_type}")
+ return diffuse
+
+ def detach_by_indices(self, x, indices):
+ x = x.clone()
+ x[:, indices] = x[:, indices].detach()
+ return x
+
+ def render_rgba(
+ self, rast_dict, verts, faces, verts_uv, faces_uv, tex, lights, background_color=[1., 1., 1.],
+ align_texture_except_fid=None, align_boundary_except_vid=None, enable_disturbance=False,
+ ):
+ """
+ Renders flame RGBA images
+ """
+
+ rast_out = rast_dict["rast_out"]
+ rast_out_db = rast_dict["rast_out_db"]
+ verts = rast_dict["verts"]
+ verts_camera = rast_dict["verts_camera"]
+ verts_clip = rast_dict["verts_clip"]
+ faces = faces.int()
+ faces_uv = faces_uv.int()
+ fg_mask = torch.clamp(rast_out[..., -1:], 0, 1).bool()
+
+ out_dict = {}
+
+ # ---- vertex attributes ----
+ if self.lighting_space == 'world':
+ v_normal = self.compute_v_normals(verts, faces)
+ elif self.lighting_space == 'camera':
+ v_normal = self.compute_v_normals(verts_camera, faces)
+ else:
+ raise NotImplementedError(f"Unknown lighting space: {self.lighting_space}")
+
+ v_attr = [v_normal]
+
+ v_attr = torch.cat(v_attr, dim=-1)
+ attr, _ = dr.interpolate(v_attr, rast_out, faces)
+ normal = attr[..., :3]
+ normal = V.safe_normalize(normal)
+
+ # ---- uv-space attributes ----
+ texc, texd = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv, rast_db=rast_out_db, diff_attrs='all')
+ if align_texture_except_fid is not None: # TODO: rethink when shading with normal
+ fid = rast_out[..., -1:].long() # the face index is shifted by +1
+ mask = torch.zeros(faces.shape[0]+1, dtype=torch.bool, device=fid.device)
+ mask[align_texture_except_fid + 1] = True
+ b, h, w = rast_out.shape[:3]
+ rast_mask = torch.gather(mask.reshape(1, 1, 1, -1).expand(b, h, w, -1), 3, fid)
+ texc = torch.where(rast_mask, texc.detach(), texc)
+
+ tex = tex.permute(0, 2, 3, 1).contiguous() # (N, T, T, 4)
+ albedo = dr.texture(tex, texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=None)
+
+ # ---- shading ----
+ diffuse = self.shade(normal, lights)
+ diffuse_detach_normal = self.shade(normal.detach(), lights)
+
+ rgb = albedo * diffuse
+ alpha = fg_mask.float()
+ rgba = torch.cat([rgb, alpha], dim=-1)
+
+ # ---- background ----
+ if isinstance(background_color, list):
+ """Background as a constant color"""
+ rgba_bg = torch.tensor(background_color + [0]).to(rgba).expand_as(rgba) # RGBA
+ elif isinstance(background_color, torch.Tensor):
+ """Background as a image"""
+ rgba_bg = background_color
+ rgba_bg = torch.cat([rgba_bg, torch.zeros_like(rgba_bg[..., :1])], dim=-1) # RGBA
+ else:
+ raise ValueError(f"Unknown background type: {type(background_color)}")
+ rgba_bg = rgba_bg.flip(1) # opengl camera has y-axis up, needs flipping
+
+ rgba = torch.where(fg_mask, rgba, rgba_bg)
+ rgba_orig = rgba
+
+ if enable_disturbance:
+ # ---- color disturbance ----
+ B, H, W, _ = rgba.shape
+ # compute random blending weights based on the disturbance rate
+ if self.disturb_rate_fg is not None:
+ w_fg = (torch.rand_like(rgba[..., :1]) < self.disturb_rate_fg).int()
+ else:
+ w_fg = torch.zeros_like(rgba[..., :1]).int()
+ if self.disturb_rate_bg is not None:
+ w_bg = (torch.rand_like(rgba[..., :1]) < self.disturb_rate_bg).int()
+ else:
+ w_bg = torch.zeros_like(rgba[..., :1]).int()
+
+ # sample pixles from clusters
+ fid = rast_out[..., -1:].long() # the face index is shifted by +1
+ num_clusters = self.fid2cid.max() + 1
+
+ fid2cid = self.fid2cid[None, None, None, :].expand(B, H, W, -1)
+ cid = torch.gather(fid2cid, -1, fid)
+ out_dict['cid'] = cid.flip(1)
+
+ rgba_ = torch.zeros_like(rgba)
+ for i in range(num_clusters):
+ c_rgba = rgba_bg if i == 0 else rgba
+ w = w_bg if i == 0 else w_fg
+
+ c_mask = cid == i
+ c_pixels = c_rgba[c_mask.repeat_interleave(4, dim=-1)].reshape(-1, 4).detach() # NOTE: detach to avoid gradient flow
+
+ if i != 1: # skip #1 indicate faces that are not in any cluster
+ if len(c_pixels) > 0:
+ c_idx = torch.randint(0, len(c_pixels), (B * H * W, ), device=c_pixels.device)
+ c_sample = c_pixels[c_idx].reshape(B, H, W, 4)
+ rgba_ += c_mask * (c_sample * w + c_rgba * (1 - w))
+ else:
+ rgba_ += c_mask * c_rgba
+ rgba = rgba_
+
+ # ---- AA on both RGB and alpha channels ----
+ if align_boundary_except_vid is not None:
+ verts_clip = self.detach_by_indices(verts_clip, align_boundary_except_vid)
+ rgba_aa = dr.antialias(rgba, rast_out, verts_clip, faces.int())
+ aa = ((rgba - rgba_aa) != 0).any(dim=-1, keepdim=True).repeat_interleave(4, dim=-1)
+
+ # rgba_aa = torch.where(aa, rgba_aa, rgba_orig) # keep the original color if not antialiased (commented out due to worse tracking performance)
+
+ # ---- AA only on RGB channels ----
+ # rgb = rgba[..., :3].contiguous()
+ # alpha = rgba[..., 3:]
+ # rgb = dr.antialias(rgb, rast_out, verts_clip, faces.int())
+ # rgba = torch.cat([rgb, alpha], dim=-1)
+
+ out_dict.update({
+ 'albedo': albedo.flip(1),
+ 'normal': normal.flip(1),
+ 'diffuse': diffuse.flip(1),
+ 'diffuse_detach_normal': diffuse_detach_normal.flip(1),
+ 'rgba': rgba_aa.flip(1),
+ 'aa': aa[..., :3].float().flip(1),
+ })
+ return out_dict
+
+ def render_without_texture(
+ self, verts, faces, RT, K, image_size, background_color=[1., 1., 1.],
+ ):
+ """
+ Renders meshes into RGBA images
+ """
+
+ verts_camera_ = self.world_to_camera(verts, RT)
+ verts_camera = verts_camera_[..., :3]
+ verts_clip = self.camera_to_clip(verts_camera_, K, image_size)
+ tri = faces.int()
+ rast_out, rast_out_db = dr.rasterize(self.glctx, verts_clip, tri, image_size)
+
+ faces = faces.int()
+ fg_mask = torch.clamp(rast_out[..., -1:], 0, 1).bool()
+ face_id = torch.clamp(rast_out[..., -1:].long() - 1, 0) # (B, W, H, 1)
+ W, H = face_id.shape[1:3]
+
+ face_normals = self.compute_face_normals(verts_camera, faces) # (B, F, 3)
+ face_normals_ = face_normals[:, None, None, :, :].expand(-1, W, H, -1, -1) # (B, 1, 1, F, 3)
+ face_id_ = face_id[:, :, :, None].expand(-1, -1, -1, -1, 3) # (B, W, H, 1, 1)
+ normal = torch.gather(face_normals_, -2, face_id_).squeeze(-2) # (B, W, H, 3)
+
+ albedo = torch.ones_like(normal)
+
+ # ---- shading ----
+ diffuse = self.shade(normal)
+
+ rgb = albedo * diffuse
+ alpha = fg_mask.float()
+ rgba = torch.cat([rgb, alpha], dim=-1)
+
+ # ---- background ----
+ if isinstance(background_color, list) or isinstance(background_color, tuple):
+ """Background as a constant color"""
+ rgba_bg = torch.tensor(list(background_color) + [0]).to(rgba).expand_as(rgba) # RGBA
+ elif isinstance(background_color, torch.Tensor):
+ """Background as a image"""
+ rgba_bg = background_color
+ rgba_bg = torch.cat([rgba_bg, torch.zeros_like(rgba_bg[..., :1])], dim=-1) # RGBA
+ else:
+ raise ValueError(f"Unknown background type: {type(background_color)}")
+ rgba_bg = rgba_bg.flip(1) # opengl camera has y-axis up, needs flipping
+
+ normal = torch.where(fg_mask, normal, rgba_bg[..., :3])
+ diffuse = torch.where(fg_mask, diffuse, rgba_bg[..., :3])
+ rgba = torch.where(fg_mask, rgba, rgba_bg)
+
+ # ---- AA on both RGB and alpha channels ----
+ rgba_aa = dr.antialias(rgba, rast_out, verts_clip, faces.int())
+
+ return {
+ 'albedo': albedo.flip(1),
+ 'normal': normal.flip(1),
+ 'diffuse': diffuse.flip(1),
+ 'rgba': rgba_aa.flip(1),
+ 'verts_clip': verts_clip,
+ }
+
+ def render_v_color(
+ self, verts, v_color, faces, RT, K, image_size, background_color=[1., 1., 1.],
+ ):
+ """
+ Renders meshes into RGBA images
+ """
+
+ verts_camera_ = self.world_to_camera(verts, RT)
+ verts_camera = verts_camera_[..., :3]
+ verts_clip = self.camera_to_clip(verts_camera_, K, image_size)
+ tri = faces.int()
+ rast_out, rast_out_db = dr.rasterize(self.glctx, verts_clip, tri, image_size)
+
+ faces = faces.int()
+ fg_mask = torch.clamp(rast_out[..., -1:], 0, 1).bool()
+ face_id = torch.clamp(rast_out[..., -1:].long() - 1, 0) # (B, W, H, 1)
+ W, H = face_id.shape[1:3]
+
+ face_normals = self.compute_face_normals(verts_camera, faces) # (B, F, 3)
+ face_normals_ = face_normals[:, None, None, :, :].expand(-1, W, H, -1, -1) # (B, 1, 1, F, 3)
+ face_id_ = face_id[:, :, :, None].expand(-1, -1, -1, -1, 3) # (B, W, H, 1, 1)
+ normal = torch.gather(face_normals_, -2, face_id_).squeeze(-2) # (B, W, H, 3)
+
+ albedo = torch.ones_like(normal)
+
+ v_attr = [v_color]
+ v_attr = torch.cat(v_attr, dim=-1)
+ attr, _ = dr.interpolate(v_attr, rast_out, faces)
+ albedo = attr[..., :3]
+
+ # ---- shading ----
+ diffuse = self.shade(normal)
+
+ rgb = albedo * diffuse
+ alpha = fg_mask.float()
+ rgba = torch.cat([rgb, alpha], dim=-1)
+
+ # ---- background ----
+ if isinstance(background_color, list) or isinstance(background_color, tuple):
+ """Background as a constant color"""
+ rgba_bg = torch.tensor(list(background_color) + [0]).to(rgba).expand_as(rgba) # RGBA
+ elif isinstance(background_color, torch.Tensor):
+ """Background as a image"""
+ rgba_bg = background_color
+ rgba_bg = torch.cat([rgba_bg, torch.zeros_like(rgba_bg[..., :1])], dim=-1) # RGBA
+ else:
+ raise ValueError(f"Unknown background type: {type(background_color)}")
+ rgba_bg = rgba_bg.flip(1) # opengl camera has y-axis up, needs flipping
+
+ normal = torch.where(fg_mask, normal, rgba_bg[..., :3])
+ diffuse = torch.where(fg_mask, diffuse, rgba_bg[..., :3])
+ rgba = torch.where(fg_mask, rgba, rgba_bg)
+
+ # ---- AA on both RGB and alpha channels ----
+ rgba_aa = dr.antialias(rgba, rast_out, verts_clip, faces.int())
+
+ return {
+ 'albedo': albedo.flip(1),
+ 'normal': normal.flip(1),
+ 'diffuse': diffuse.flip(1),
+ 'rgba': rgba_aa.flip(1),
+ }
diff --git a/vhap/util/render_uvmap.py b/vhap/util/render_uvmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e7fbbd7baacc4ba5052aa2453381a24e188952d
--- /dev/null
+++ b/vhap/util/render_uvmap.py
@@ -0,0 +1,86 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+import tyro
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import nvdiffrast.torch as dr
+
+from vhap.model.flame import FlameHead
+
+
+FLAME_TEX_PATH = "asset/flame/FLAME_texture.npz"
+
+
+def transform_vt(vt):
+ """Transform uv vertices to clip space"""
+ xy = vt * 2 - 1
+ w = torch.ones([1, vt.shape[-2], 1]).to(vt)
+ z = -w # In the clip spcae of OpenGL, the camera looks at -z
+ xyzw = torch.cat([xy[None, :, :], z, w], axis=-1)
+ return xyzw
+
+def render_uvmap_vtex(glctx, pos, pos_idx, v_color, col_idx, resolution):
+ """Render uv map with vertex color"""
+ pos_clip = transform_vt(pos)
+ rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution)
+
+ color, _ = dr.interpolate(v_color, rast_out, col_idx)
+ color = dr.antialias(color, rast_out, pos_clip, pos_idx)
+ return color
+
+def render_uvmap_texmap(glctx, pos, pos_idx, verts_uv, faces_uv, tex, resolution, enable_mip=True, max_mip_level=None):
+ """Render uv map with texture map"""
+ pos_clip = transform_vt(pos)
+ rast_out, rast_out_db = dr.rasterize(glctx, pos_clip, pos_idx, resolution)
+
+ if enable_mip:
+ texc, texd = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv, rast_db=rast_out_db, diff_attrs='all')
+ color = dr.texture(tex[None, ...], texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=max_mip_level)
+ else:
+ texc, _ = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv)
+ color = dr.texture(tex[None, ...], texc, filter_mode='linear')
+ color = dr.antialias(color, rast_out, pos_clip, pos_idx)
+ return color
+
+
+def main(
+ use_texmap: bool = False,
+ use_opengl: bool = False,
+):
+ n_shape = 300
+ n_expr = 100
+ print("Initialization FLAME model")
+ flame_model = FlameHead(n_shape, n_expr)
+
+ verts_uv = flame_model.verts_uvs.cuda()
+ verts_uv[:, 1] = 1 - verts_uv[:, 1]
+ faces_uv = flame_model.textures_idx.int().cuda()
+
+ # Rasterizer context
+ glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()
+
+ h, w = 512, 512
+ resolution = (h, w)
+
+ if use_texmap:
+ tex = torch.from_numpy(np.load(FLAME_TEX_PATH)['mean']).cuda().float().flip(dims=[-1]) / 255
+ rgb = render_uvmap_texmap(glctx, verts_uv, faces_uv, verts_uv, faces_uv, tex, resolution, enable_mip=True)
+ else:
+ v_color = torch.ones(verts_uv.shape[0], 3).to(verts_uv)
+ col_idx = faces_uv
+ rgb = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)
+
+ plt.imshow(rgb[0, :, :].cpu())
+ plt.show()
+
+
+if __name__ == "__main__":
+ tyro.cli(main)
diff --git a/vhap/util/vector_ops.py b/vhap/util/vector_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..50db8371fba8b3516360ea6b8cf18e679411d3d5
--- /dev/null
+++ b/vhap/util/vector_ops.py
@@ -0,0 +1,17 @@
+import torch
+
+
+def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ return torch.sum(x*y, -1, keepdim=True)
+
+def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
+ return 2*dot(x, n)*n - x
+
+def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
+ return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
+
+def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
+ return x / length(x, eps)
+
+def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
+ return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w)
diff --git a/vhap/util/visualization.py b/vhap/util/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..17e87c14ed0db4c425b4fdeae2ad6cc7e4b53e96
--- /dev/null
+++ b/vhap/util/visualization.py
@@ -0,0 +1,126 @@
+#
+# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
+# property and proprietary rights in and to this software and related documentation.
+# Any commercial use, reproduction, disclosure or distribution of this software and
+# related documentation without an express license agreement from Toyota Motor Europe NV/SA
+# is strictly prohibited.
+#
+
+
+import matplotlib.pyplot as plt
+import torch
+from torchvision.utils import draw_bounding_boxes, draw_keypoints
+
+
+connectivity_face = (
+ [(i, i + 1) for i in list(range(0, 16))]
+ + [(i, i + 1) for i in list(range(17, 21))]
+ + [(i, i + 1) for i in list(range(22, 26))]
+ + [(i, i + 1) for i in list(range(27, 30))]
+ + [(i, i + 1) for i in list(range(31, 35))]
+ + [(i, i + 1) for i in list(range(36, 41))]
+ + [(36, 41)]
+ + [(i, i + 1) for i in list(range(42, 47))]
+ + [(42, 47)]
+ + [(i, i + 1) for i in list(range(48, 59))]
+ + [(48, 59)]
+ + [(i, i + 1) for i in list(range(60, 67))]
+ + [(60, 67)]
+)
+
+
+def plot_landmarks_2d(
+ img: torch.tensor,
+ lmks: torch.tensor,
+ connectivity=None,
+ colors="white",
+ unit=1,
+ input_float=False,
+):
+ if input_float:
+ img = (img * 255).byte()
+
+ img = draw_keypoints(
+ img,
+ lmks,
+ connectivity=connectivity,
+ colors=colors,
+ radius=2 * unit,
+ width=2 * unit,
+ )
+
+ if input_float:
+ img = img.float() / 255
+ return img
+
+
+def blend(a, b, w):
+ return (a * w + b * (1 - w)).byte()
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ from torch.utils.data import DataLoader
+ from matplotlib import pyplot as plt
+
+ from vhap.data.nersemble_dataset import NeRSembleDataset
+
+ parser = ArgumentParser()
+ parser.add_argument("--root_folder", type=str, required=True)
+ parser.add_argument("--subject", type=str, required=True)
+ parser.add_argument("--sequence", type=str, required=True)
+ parser.add_argument("--division", default=None)
+ parser.add_argument("--subset", default=None)
+ parser.add_argument("--scale_factor", type=float, default=1.0)
+ parser.add_argument("--blend_weight", type=float, default=0.6)
+ args = parser.parse_args()
+
+ dataset = NeRSembleDataset(
+ root_folder=args.root_folder,
+ subject=args.subject,
+ sequence=args.sequence,
+ division=args.division,
+ subset=args.subset,
+ n_downsample_rgb=2,
+ scale_factor=args.scale_factor,
+ use_landmark=True,
+ )
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
+
+ for item in dataloader:
+ unit = int(item["scale_factor"][0] * 3) + 1
+
+ rgb = item["rgb"][0].permute(2, 0, 1)
+ vis = rgb
+
+ if "bbox_2d" in item:
+ bbox = item["bbox_2d"][0][:4]
+ tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit)
+ vis = blend(tmp, vis, args.blend_weight)
+
+ if "lmk2d" in item:
+ face_landmark = item["lmk2d"][0][:, :2]
+ tmp = plot_landmarks_2d(
+ vis,
+ face_landmark[None, ...],
+ connectivity=connectivity_face,
+ colors="white",
+ unit=unit,
+ )
+ vis = blend(tmp, vis, args.blend_weight)
+
+ if "lmk2d_iris" in item:
+ iris_landmark = item["lmk2d_iris"][0][:, :2]
+ tmp = plot_landmarks_2d(
+ vis,
+ iris_landmark[None, ...],
+ colors="blue",
+ unit=unit,
+ )
+ vis = blend(tmp, vis, args.blend_weight)
+
+ vis = vis.permute(1, 2, 0).numpy()
+ plt.imshow(vis)
+ plt.draw()
+ while not plt.waitforbuttonpress(timeout=-1):
+ pass