diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f6b1f326ca4ab7cf0c8798856f8fe0020ff82d58 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..704ba719fc02c95d9f938ecf5fd5800414360be3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +output +outputs +**__pycache__ +.DS_Store +cache +step1x3d_texture/custom_rasterizer/build +step1x3d_texture/custom_rasterizer/dist +step1x3d_texture/custom_rasterizer/custom_rasterizer.egg-info +step1x3d_texture/differentiable_renderer/build +step1x3d_texture/differentiable_renderer/dist +step1x3d_texture/differentiable_renderer/mesh_processor.egg-info \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..aa55837e5c6cfffc81b0b1d784d9e2ee846d38e0 --- /dev/null +++ b/app.py @@ -0,0 +1,135 @@ +import os +import time +import uuid +import torch +import trimesh +import argparse +import numpy as np +import gradio as gr +from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline +from step1x3d_texture.pipelines.step1x_3d_texture_synthesis_pipeline import ( + Step1X3DTexturePipeline, +) +from step1x3d_texture.utils.shape_post_process import ( + FaceReducer, + DegenerateFaceRemover, +) + + +def generate_func( + input_image_path, guidance_scale, inference_steps, max_facenum, symmetry, edge_type +): + if "Label" in args.geometry_model: + out = geometry_model( + input_image_path, + label={"symmetry": symmetry, "edge_type": edge_type}, + guidance_scale=float(guidance_scale), + octree_resolution=384, + max_facenum=int(max_facenum), + num_inference_steps=int(inference_steps), + ) + else: + out = geometry_model( + input_image_path, + guidance_scale=float(guidance_scale), + num_inference_steps=int(inference_steps), + max_facenum=int(max_facenum), + ) + + save_name = str(uuid.uuid4()) + print(save_name) + geometry_save_path = f"{args.cache_dir}/{save_name}.glb" + geometry_mesh = out.mesh[0] + geometry_mesh.export(geometry_save_path) + + geometry_mesh = DegenerateFaceRemover()(geometry_mesh) + geometry_mesh = FaceReducer()(geometry_mesh) + textured_mesh = texture_model(input_image_path, geometry_mesh) + textured_save_path = f"{args.cache_dir}/{save_name}-textured.glb" + textured_mesh.export(textured_save_path) + + torch.cuda.empty_cache() + print("Generate finish") + return geometry_save_path, textured_save_path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--geometry_model", type=str, default="Step1X-3D-Geometry-Label-1300m" + ) + parser.add_argument( + "--texture_model", type=str, default="Step1X-3D-Texture" + ) + parser.add_argument("--cache_dir", type=str, default="cache") + parser.add_argument("--port", type=int, default=7861) + parser.add_argument("--host", type=str, default="0.0.0.0") + args = parser.parse_args() + + os.makedirs(args.cache_dir, exist_ok=True) + + geometry_model = Step1X3DGeometryPipeline.from_pretrained( + "stepfun-ai/Step1X-3D", subfolder=args.geometry_model + ).to("cuda") + + texture_model = Step1X3DTexturePipeline.from_pretrained("stepfun-ai/Step1X-3D", subfolder=args.texture_model) + + with gr.Blocks(title="Step1X-3D demo") as demo: + gr.Markdown("# Step1X-3D") + with gr.Row(): + with gr.Column(scale=2): + input_image = gr.Image( + label="Image", type="filepath", image_mode="RGBA" + ) + guidance_scale = gr.Number(label="Guidance Scale", value="7.5") + inference_steps = gr.Slider( + label="Inferece Steps", minimum=1, maximum=100, value=50 + ) + max_facenum = gr.Number(label="Max Face Num", value="400000") + symmetry = gr.Radio( + choices=["x", "asymmetry"], + label="Symmetry Type", + value="x", + type="value", + ) + edge_type = gr.Radio( + choices=["sharp", "normal", "smooth"], + label="Edge Type", + value="sharp", + type="value", + ) + btn = gr.Button("Start") + with gr.Column(scale=4): + textured_preview = gr.Model3D(label="Textured", height=380) + geometry_preview = gr.Model3D(label="Geometry", height=380) + with gr.Column(scale=1): + gr.Examples( + examples=[ + ["examples/images/000.png"], + ["examples/images/001.png"], + ["examples/images/004.png"], + ["examples/images/008.png"], + ["examples/images/028.png"], + ["examples/images/032.png"], + ["examples/images/061.png"], + ["examples/images/107.png"], + ], + inputs=[input_image], + cache_examples=False, + ) + + btn.click( + generate_func, + inputs=[ + input_image, + guidance_scale, + inference_steps, + max_facenum, + symmetry, + edge_type, + ], + outputs=[geometry_preview, textured_preview], + ) + + demo.launch(server_name=args.host, server_port=args.port) + demo.queue(concurrency_count=3) diff --git a/examples/images/000.png b/examples/images/000.png new file mode 100644 index 0000000000000000000000000000000000000000..7a9cea2ef484d93a8290a3a007396ba822a2b8c4 --- /dev/null +++ b/examples/images/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62284b41c010dd81524c51d12da4369fc458abd955011f59ce395266a02efb5f +size 1542412 diff --git a/examples/images/001.png b/examples/images/001.png new file mode 100644 index 0000000000000000000000000000000000000000..26224646aa9869af01dc37b3d5ae71018cda88da --- /dev/null +++ b/examples/images/001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e93cc2c9850b6ea7cf233ae2f8d96246d86de7fc1d9bf079f2455a47938e946a +size 607851 diff --git a/examples/images/004.png b/examples/images/004.png new file mode 100644 index 0000000000000000000000000000000000000000..d2bf0bae8a6cbb0a444db3b65f27370258e9bdc7 --- /dev/null +++ b/examples/images/004.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19aa7e05ca0cb1eb4e7809eeded332cce8c21daf9e5458338b6ad3bfbba85679 +size 1298013 diff --git a/examples/images/008.png b/examples/images/008.png new file mode 100644 index 0000000000000000000000000000000000000000..0138f2fd4ebfd62c7b4610fbbf6cd7d009cf3f92 --- /dev/null +++ b/examples/images/008.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67cf8e33b715641599c5489f06f6c5d1da312faf3c95196395d9d81a1aa112e1 +size 366617 diff --git a/examples/images/028.png b/examples/images/028.png new file mode 100644 index 0000000000000000000000000000000000000000..5604ef1b40f03ed2c75046764073edd84d405c80 --- /dev/null +++ b/examples/images/028.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b12c3b18f615fb5c887bfbd946c69eff8934519182ee5ef13f3853ca64e0bc22 +size 1298098 diff --git a/examples/images/032.png b/examples/images/032.png new file mode 100644 index 0000000000000000000000000000000000000000..a0966169c03bd161eca407ced77a78cf7897a9c6 --- /dev/null +++ b/examples/images/032.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f655fc199fed98a8d663e6e39baa94307af3e9494efa6389ac5b90c81b45b18 +size 1563171 diff --git a/examples/images/061.png b/examples/images/061.png new file mode 100644 index 0000000000000000000000000000000000000000..83bce6213408d5966b28b5970b798d9566707126 --- /dev/null +++ b/examples/images/061.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e28ffd293ba94f8d92c7bef7db7125d6df5e05287f116d6f93617623aa5d7ecf +size 306501 diff --git a/examples/images/107.png b/examples/images/107.png new file mode 100644 index 0000000000000000000000000000000000000000..6cf7d8a91f0d65461a77a17a145d33bf97b9dd2e --- /dev/null +++ b/examples/images/107.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70c7d618bfd70125d0b61007e549f3369273b1de866b30c703a68045bceb8950 +size 1271461 diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..785b00215741d808dbd29f57474dbc7cfa450c1a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,50 @@ +datasets==2.19.0 +diffusers==0.32.2 +einops==0.8.0 +huggingface-hub==0.26.2 +imageio==2.34.1 +jaxtyping==0.2.28 +joblib==1.4.0 +lightning-utilities==0.11.2 +matplotlib==3.8.4 +numpy==1.26.4 +omegaconf==2.3.0 +opencv-python-headless==4.10.0.84 +pandas==2.2.2 +pillow==10.3.0 +plyfile==1.0.3 +PyMCubes==0.1.4 +pyparsing==3.1.2 +pytorch-lightning==2.2.4 +PyYAML==6.0.1 +safetensors==0.4.3 +scikit-image==0.23.2 +scipy==1.13.0 +tensorboard==2.16.2 +tensorboardX==2.6.2.2 +timm==0.9.16 +tokenizers==0.21.0 +tqdm==4.66.2 +transformers==4.48.0 +trimesh==4.3.2 +spaces==0.28.3 +accelerate==1.5.2 +rembg==2.0.65 +gradio==5.5.0 +wandb==0.18.6 +deepspeed==0.16.4 +sageattention==1.0.6 +mosaicml-streaming==0.11.0 +easydict==1.13 +open3d==0.19.0 +prodigyopt==1.1.2 +peft==0.15.1 +sentencepiece==0.2.0 +pymeshlab==2023.12.post3 +onnxruntime==1.21.0 +bs4==0.0.2 +xatlas==0.0.10 +pybind11==2.13.6 +pygltflib==1.16.4 +kornia==0.8.0 +git+https://github.com/NVlabs/nvdiffrast.git diff --git a/step1x3d_geometry/__init__.py b/step1x3d_geometry/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..aaae1910181ee66ba098822f3009e44765e4a3a2 --- /dev/null +++ b/step1x3d_geometry/__init__.py @@ -0,0 +1,52 @@ +import importlib + +__modules__ = {} + + +def register(name): + def decorator(cls): + if name in __modules__: + raise ValueError( + f"Module {name} already exists! Names of extensions conflict!" + ) + else: + __modules__[name] = cls + return cls + + return decorator + + +def find(name): + if name in __modules__: + return __modules__[name] + else: + try: + module_string = ".".join(name.split(".")[:-1]) + cls_name = name.split(".")[-1] + module = importlib.import_module(module_string, package=None) + return getattr(module, cls_name) + except Exception as e: + raise ValueError(f"Module {name} not found!") + + +### grammar sugar for logging utilities ### +import logging + +logger = logging.getLogger("pytorch_lightning") + +from pytorch_lightning.utilities.rank_zero import ( + rank_zero_debug, + rank_zero_info, + rank_zero_only, +) + +debug = rank_zero_debug +info = rank_zero_info + + +@rank_zero_only +def warn(*args, **kwargs): + logger.warn(*args, **kwargs) + + +from . import data, models, systems diff --git a/step1x3d_geometry/data/Objaverse.py b/step1x3d_geometry/data/Objaverse.py new file mode 100755 index 0000000000000000000000000000000000000000..367f22926630731e35686f560919b29dd9cb0bd9 --- /dev/null +++ b/step1x3d_geometry/data/Objaverse.py @@ -0,0 +1,73 @@ +import math +import os +import json +import re +import cv2 +from dataclasses import dataclass, field + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from step1x3d_geometry import register +from step1x3d_geometry.utils.typing import * +from step1x3d_geometry.utils.config import parse_structured + +from streaming import StreamingDataLoader +from .base import BaseDataModuleConfig, BaseDataset + + +@dataclass +class ObjaverseDataModuleConfig(BaseDataModuleConfig): + pass + + +class ObjaverseDataset(BaseDataset): + pass + + +@register("Objaverse-datamodule") +class ObjaverseDataModule(pl.LightningDataModule): + cfg: ObjaverseDataModuleConfig + + def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: + super().__init__() + self.cfg = parse_structured(ObjaverseDataModuleConfig, cfg) + + def setup(self, stage=None) -> None: + if stage in [None, "fit"]: + self.train_dataset = ObjaverseDataset(self.cfg, "train") + if stage in [None, "fit", "validate"]: + self.val_dataset = ObjaverseDataset(self.cfg, "val") + if stage in [None, "test", "predict"]: + self.test_dataset = ObjaverseDataset(self.cfg, "test") + + def prepare_data(self): + pass + + def general_loader( + self, dataset, batch_size, collate_fn=None, num_workers=0 + ) -> DataLoader: + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_fn, + num_workers=num_workers, + ) + + def train_dataloader(self) -> DataLoader: + return self.general_loader( + self.train_dataset, + batch_size=self.cfg.batch_size, + collate_fn=self.train_dataset.collate, + num_workers=self.cfg.num_workers, + ) + + def val_dataloader(self) -> DataLoader: + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self) -> DataLoader: + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self) -> DataLoader: + return self.general_loader(self.test_dataset, batch_size=1) diff --git a/step1x3d_geometry/data/__init__.py b/step1x3d_geometry/data/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a549c94236c5d26100912bd3abf3b51e8938ddec --- /dev/null +++ b/step1x3d_geometry/data/__init__.py @@ -0,0 +1 @@ +from . import Objaverse diff --git a/step1x3d_geometry/data/base.py b/step1x3d_geometry/data/base.py new file mode 100755 index 0000000000000000000000000000000000000000..5440369ee96545900ec15f278739ba5716f48d5a --- /dev/null +++ b/step1x3d_geometry/data/base.py @@ -0,0 +1,350 @@ +import math +import os +import json +import re +import cv2 +from dataclasses import dataclass, field + +import random +import imageio +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms +from torch.utils.data import DataLoader, Dataset +from PIL import Image + +from step1x3d_geometry.utils.typing import * + + +@dataclass +class BaseDataModuleConfig: + root_dir: str = None + batch_size: int = 4 + num_workers: int = 8 + + ################################# General argumentation ################################# + random_flip: bool = ( + False # whether to randomly flip the input point cloud and the input images + ) + + ################################# Geometry part ################################# + load_geometry: bool = True # whether to load geometry data + with_sharp_data: bool = False + geo_data_type: str = "sdf" # occupancy, sdf + # for occupancy or sdf supervision + n_samples: int = 4096 # number of points in input point cloud + upsample_ratio: int = 1 # upsample ratio for input point cloud + sampling_strategy: Optional[str] = ( + "random" # sampling strategy for input point cloud + ) + scale: float = 1.0 # scale of the input point cloud and target supervision + noise_sigma: float = 0.0 # noise level of the input point cloud + rotate_points: bool = ( + False # whether to rotate the input point cloud and the supervision, for VAE aug. + ) + load_geometry_supervision: bool = False # whether to load supervision + supervision_type: str = "sdf" # occupancy, sdf, tsdf, tsdf_w_surface + n_supervision: int = 10000 # number of points in supervision + tsdf_threshold: float = ( + 0.01 # threshold for truncating sdf values, used when input is sdf + ) + + ################################# Image part ################################# + load_image: bool = False # whether to load images + image_type: str = "rgb" # rgb, normal, rgb_or_normal + image_file_type: str = "png" # png, jpeg + image_type_ratio: float = ( + 1.0 # ratio of rgb for each dataset when image_type is "rgb_or_normal" + ) + crop_image: bool = True # whether to crop the input image + random_color_jitter: bool = ( + False # whether to randomly color jitter the input images + ) + random_rotate: bool = ( + False # whether to randomly rotate the input images, default [-10 deg, 10 deg] + ) + random_mask: bool = False # whether to add random mask to the input image + background_color: Tuple[int, int, int] = field( + default_factory=lambda: (255, 255, 255) + ) + idx: Optional[List[int]] = None # index of the image to load + n_views: int = 1 # number of views + foreground_ratio: Optional[float] = 0.90 + + ################################# Caption part ################################# + load_caption: bool = False # whether to load captions + load_label: bool = False # whether to load labels + + +class BaseDataset(Dataset): + def __init__(self, cfg: Any, split: str) -> None: + super().__init__() + self.cfg: BaseDataModuleConfig = cfg + self.split = split + + self.uids = json.load(open(f"{cfg.root_dir}/{split}.json")) + print(f"Loaded {len(self.uids)} {split} uids") + + # add ColorJitter transforms for input images + if self.cfg.random_color_jitter: + self.color_jitter = transforms.ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2 + ) + + # add RandomRotation transforms for input images + if self.cfg.random_rotate: + self.rotate = transforms.RandomRotation( + degrees=10, fill=(*self.cfg.background_color, 0.0) + ) # by default 10 deg + + def __len__(self): + return len(self.uids) + + def _load_shape_from_occupancy_or_sdf(self, index: int) -> Dict[str, Any]: + if self.cfg.geo_data_type == "sdf": + data = np.load(f"{self.cfg.root_dir}/surfaces/{self.uids[index]}.npz") + # for input point cloud + surface = data["surface"] + if self.cfg.with_sharp_data: + sharp_surface = data["sharp_surface"] + else: + raise NotImplementedError( + f"Data type {self.cfg.geo_data_type} not implemented" + ) + + # random sampling + if self.cfg.sampling_strategy == "random": + rng = np.random.default_rng() + ind = rng.choice( + surface.shape[0], + self.cfg.upsample_ratio * self.cfg.n_samples, + replace=True, + ) + surface = surface[ind] + if self.cfg.with_sharp_data: + sharp_surface = sharp_surface[ind] + elif self.cfg.sampling_strategy == "fps": + import fpsample + + kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling( + surface[:, :3], self.cfg.n_samples, h=5 + ) + surface = surface[kdline_fps_samples_idx] + if self.cfg.with_sharp_data: + kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling( + sharp_surface[:, :3], self.cfg.n_samples, h=5 + ) + sharp_surface = sharp_surface[kdline_fps_samples_idx] + else: + raise NotImplementedError( + f"sampling strategy {self.cfg.sampling_strategy} not implemented" + ) + + # rescale data + surface[:, :3] = surface[:, :3] * self.cfg.scale # target scale + if self.cfg.with_sharp_data: + sharp_surface[:, :3] = sharp_surface[:, :3] * self.cfg.scale # target scale + ret = { + "uid": self.uids[index].split("/")[-1], + "surface": surface.astype(np.float32), + "sharp_surface": sharp_surface.astype(np.float32), + } + else: + ret = { + "uid": self.uids[index].split("/")[-1], + "surface": surface.astype(np.float32), + } + + return ret + + def _load_shape_supervision_occupancy_or_sdf(self, index: int) -> Dict[str, Any]: + # for supervision + ret = {} + if self.cfg.geo_data_type == "sdf": + data = np.load(f"{self.cfg.root_dir}/surfaces/{self.uids[index]}.npz") + data = np.concatenate( + [data["volume_rand_points"], data["near_surface_points"]], axis=0 + ) + rand_points, sdfs = data[:, :3], data[:, 3:] + else: + raise NotImplementedError( + f"Data type {self.cfg.geo_data_type} not implemented" + ) + + # random sampling + rng = np.random.default_rng() + ind = rng.choice(rand_points.shape[0], self.cfg.n_supervision, replace=False) + rand_points = rand_points[ind] + rand_points = rand_points * self.cfg.scale + ret["rand_points"] = rand_points.astype(np.float32) + + if self.cfg.geo_data_type == "sdf": + if self.cfg.supervision_type == "sdf": + ret["sdf"] = sdfs[ind].flatten().astype(np.float32) + elif self.cfg.supervision_type == "occupancy": + ret["occupancies"] = np.where(sdfs[ind].flatten() < 1e-3, 0, 1).astype( + np.float32 + ) + elif self.cfg.supervision_type == "tsdf": + ret["sdf"] = ( + sdfs[ind] + .flatten() + .astype(np.float32) + .clip(-self.cfg.tsdf_threshold, self.cfg.tsdf_threshold) + / self.cfg.tsdf_threshold + ) + else: + raise NotImplementedError( + f"Supervision type {self.cfg.supervision_type} not implemented" + ) + + return ret + + def _load_image(self, index: int) -> Dict[str, Any]: + def _process_img(image, background_color=(255, 255, 255), foreground_ratio=0.9): + alpha = image.getchannel("A") + background = Image.new("RGBA", image.size, (*background_color, 255)) + image = Image.alpha_composite(background, image) + image = image.crop(alpha.getbbox()) + + new_size = tuple(int(dim * foreground_ratio) for dim in image.size) + resized_image = image.resize(new_size) + padded_image = Image.new("RGBA", image.size, (*background_color, 255)) + paste_position = ( + (image.width - resized_image.width) // 2, + (image.height - resized_image.height) // 2, + ) + padded_image.paste(resized_image, paste_position) + + # Expand image to 1:1 + max_dim = max(padded_image.size) + image = Image.new("RGBA", (max_dim, max_dim), (*background_color, 255)) + paste_position = ( + (max_dim - padded_image.width) // 2, + (max_dim - padded_image.height) // 2, + ) + image.paste(padded_image, paste_position) + image = image.resize((512, 512)) + return image.convert("RGB"), alpha + + ret = {} + if self.cfg.image_type == "rgb" or self.cfg.image_type == "normal": + assert ( + self.cfg.n_views == 1 + ), "Only single view is supported for single image" + sel_idx = random.choice(self.cfg.idx) + ret["sel_image_idx"] = sel_idx + if self.cfg.image_type == "rgb": + img_path = ( + f"{self.cfg.root_dir}/images/" + + "/".join(self.uids[index].split("/")[-2:]) + + f"/{'{:04d}'.format(sel_idx)}_rgb.{self.cfg.image_file_type}" + ) + elif self.cfg.image_type == "normal": + img_path = ( + f"{self.cfg.root_dir}/images/" + + "/".join(self.uids[index].split("/")[-2:]) + + f"/{'{:04d}'.format(sel_idx)}_normal.{self.cfg.image_file_type}" + ) + image = Image.open(img_path).copy() + + # add random color jitter + if self.cfg.random_color_jitter: + rgb = self.color_jitter(image.convert("RGB")) + image = Image.merge("RGBA", (*rgb.split(), image.getchannel("A"))) + + # add random rotation + if self.cfg.random_rotate: + image = self.rotate(image) + + # add crop + if self.cfg.crop_image: + background_color = ( + torch.randint(0, 256, (3,)) + if self.cfg.background_color is None + else torch.as_tensor(self.cfg.background_color) + ) + image, alpha = _process_img( + image, background_color, self.cfg.foreground_ratio + ) + else: + alpha = image.getchannel("A") + background = Image.new("RGBA", image.size, background_color) + image = Image.alpha_composite(background, image).convert("RGB") + + ret["image"] = torch.from_numpy(np.array(image) / 255.0) + ret["mask"] = torch.from_numpy(np.array(alpha) / 255.0).unsqueeze(0) + else: + raise NotImplementedError( + f"Image type {self.cfg.image_type} not implemented" + ) + + return ret + + def _get_data(self, index): + ret = {"uid": self.uids[index]} + + # random flip + flip = np.random.rand() < 0.5 if self.cfg.random_flip else False + + # load geometry + if self.cfg.load_geometry: + if self.cfg.geo_data_type == "occupancy" or self.cfg.geo_data_type == "sdf": + # load shape + ret = self._load_shape_from_occupancy_or_sdf(index) + # load supervision for shape + if self.cfg.load_geometry_supervision: + ret.update(self._load_shape_supervision_occupancy_or_sdf(index)) + else: + raise NotImplementedError( + f"Geo data type {self.cfg.geo_data_type} not implemented" + ) + + if flip: # random flip the input point cloud and the supervision + for key in ret.keys(): + if key in ["surface", "sharp_surface"]: # N x (xyz + normal) + ret[key][:, 0] = -ret[key][:, 0] + ret[key][:, 3] = -ret[key][:, 3] + elif key in ["rand_points"]: + ret[key][:, 0] = -ret[key][:, 0] + + # load image + if self.cfg.load_image: + ret.update(self._load_image(index)) + if flip: # random flip the input image + for key in ret.keys(): + if key in ["image"]: # random flip the input image + ret[key] = torch.flip(ret[key], [2]) + if key in ["mask"]: # random flip the input image + ret[key] = torch.flip(ret[key], [2]) + + # load caption + meta = None + if self.cfg.load_caption: + with open(f"{self.cfg.root_dir}/metas/{self.uids[index]}.json", "r") as f: + meta = json.load(f) + ret.update({"caption": meta["caption"]}) + + # load label + if self.cfg.load_label: + if meta is None: + with open( + f"{self.cfg.root_dir}/metas/{self.uids[index]}.json", "r" + ) as f: + meta = json.load(f) + ret.update({"label": [meta["label"]]}) + + return ret + + def __getitem__(self, index): + try: + return self._get_data(index) + except Exception as e: + print(f"Error in {self.uids[index]}: {e}") + return self.__getitem__(np.random.randint(len(self))) + + def collate(self, batch): + from torch.utils.data._utils.collate import default_collate_fn_map + + return torch.utils.data.default_collate(batch) diff --git a/step1x3d_geometry/models/__init__.py b/step1x3d_geometry/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..62ea00e590094dea226306595dcaa1983c29342e --- /dev/null +++ b/step1x3d_geometry/models/__init__.py @@ -0,0 +1 @@ +from . import autoencoders, conditional_encoders, transformers diff --git a/step1x3d_geometry/models/attention.py b/step1x3d_geometry/models/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..fb0ab12f1d003ef8b3b6aeba221103dc77a815ed --- /dev/null +++ b/step1x3d_geometry/models/attention.py @@ -0,0 +1,776 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple, Union +import collections.abc +from itertools import repeat + +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.models.normalization import ( + AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle, + FP32LayerNorm, + LayerNorm, +) + +from .attention_processor import FluxAttnProcessor2_0, AttnProcessor2_0 + + +@maybe_allow_in_graph +class MultiCondBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + use_self_attention: bool = True, + use_cross_attention: bool = False, + self_attention_norm_type: Optional[ + str + ] = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + cross_attention_dim: Optional[int] = None, + cross_attention_norm_type: Optional[str] = None, + # parallel second cross attention + use_cross_attention_2: bool = False, + cross_attention_2_dim: Optional[int] = None, + cross_attention_2_norm_type: Optional[str] = None, + # parallel third cross attention + use_cross_attention_3: bool = False, + cross_attention_3_dim: Optional[int] = None, + cross_attention_3_norm_type: Optional[str] = None, + dropout=0.0, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.use_self_attention = use_self_attention + self.use_cross_attention = use_cross_attention + self.self_attention_norm_type = self_attention_norm_type + self.cross_attention_dim = cross_attention_dim + self.cross_attention_norm_type = cross_attention_norm_type + self.use_cross_attention_2 = use_cross_attention_2 + self.cross_attention_2_dim = cross_attention_2_dim + self.cross_attention_2_norm_type = cross_attention_2_norm_type + self.use_cross_attention_3 = use_cross_attention_3 + self.cross_attention_3_dim = cross_attention_3_dim + self.cross_attention_3_norm_type = cross_attention_3_norm_type + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None + ) and self_attention_norm_type == "ada_norm_zero" + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None + ) and self_attention_norm_type == "ada_norm" + self.use_ada_layer_norm_single = self_attention_norm_type == "ada_norm_single" + self.use_layer_norm = self_attention_norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = ( + self_attention_norm_type == "ada_norm_continuous" + ) + + if ( + self_attention_norm_type in ("ada_norm", "ada_norm_zero") + and num_embeds_ada_norm is None + ): + raise ValueError( + f"`self_attention_norm_type` is set to {self_attention_norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `self_attention_norm_type` to {self_attention_norm_type}." + ) + + self.self_attention_norm_type = self_attention_norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding( + dim, max_seq_length=num_positional_embeddings + ) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + if use_self_attention: + # 1. Self-Attn + if self_attention_norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self_attention_norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif self_attention_norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + elif ( + self_attention_norm_type == "fp32_layer_norm" + or self_attention_norm_type is None + ): + self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm1 = nn.RMSNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=dim // num_attention_heads, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=( + cross_attention_dim if only_cross_attention else None + ), + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + processor=AttnProcessor2_0(), + ) + + # 2. Cross-Attn + if use_cross_attention or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if cross_attention_norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif cross_attention_norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + elif ( + cross_attention_norm_type == "fp32_layer_norm" + or cross_attention_norm_type is None + ): + self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = nn.RMSNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=( + cross_attention_dim if not double_self_attention else None + ), + heads=num_attention_heads, + dim_head=dim // num_attention_heads, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + processor=AttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 2'. Parallel Second Cross-Attn + if use_cross_attention_2: + assert cross_attention_2_dim is not None + if cross_attention_2_norm_type == "ada_norm": + self.norm2_2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif cross_attention_2_norm_type == "ada_norm_continuous": + self.norm2_2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + elif ( + cross_attention_2_norm_type == "fp32_layer_norm" + or cross_attention_2_norm_type is None + ): + self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2_2 = nn.RMSNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.attn2_2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_2_dim, + heads=num_attention_heads, + dim_head=dim // num_attention_heads, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + processor=AttnProcessor2_0(), + ) + + # self.attn2_2 = Attention( + # query_dim=dim, + # cross_attention_dim=cross_attention_2_dim, + # dim_head=dim // num_attention_heads, + # heads=num_attention_heads, + # qk_norm="rms_norm" if qk_norm else None, + # cross_attention_norm=cross_attention_2_norm_type, + # eps=1e-6, + # bias=qkv_bias, + # processor=AttnProcessor2_0(), + # ) + else: + self.norm2_2 = None + self.attn2_2 = None + + # 2'. Parallel Third Cross-Attn + if use_cross_attention_3: + assert cross_attention_3_dim is not None + if cross_attention_3_norm_type == "ada_norm": + self.norm2_3 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif cross_attention_3_norm_type == "ada_norm_continuous": + self.norm2_3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + elif ( + cross_attention_3_norm_type == "fp32_layer_norm" + or cross_attention_3_norm_type is None + ): + self.norm2_3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2_3 = nn.RMSNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.attn2_3 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_3_dim, + heads=num_attention_heads, + dim_head=dim // num_attention_heads, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + processor=AttnProcessor2_0(), + ) + else: + self.norm2_3 = None + self.attn2_3 = None + + # 3. Feed-forward + if self_attention_norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif self_attention_norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif self_attention_norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense( + dim, cross_attention_dim, num_attention_heads, attention_head_dim + ) + + # 5. Scale-shift for PixArt-Alpha. + if self_attention_norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + encoder_hidden_states_3: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask_2: Optional[torch.Tensor] = None, + encoder_attention_mask_3: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored." + ) + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.self_attention_norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.self_attention_norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.self_attention_norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.self_attention_norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + elif self.self_attention_norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.self_attention_norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.self_attention_norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.cross_attention_norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.cross_attention_norm_type in [ + "ada_norm_zero", + "layer_norm", + "layer_norm_i2vgen", + ]: + norm_hidden_states = self.norm2(hidden_states) + elif self.cross_attention_norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.cross_attention_norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + else: + raise ValueError("Incorrect norm") + + if ( + self.pos_embed is not None + and self.cross_attention_norm_type != "ada_norm_single" + ): + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3.1 Parallel Second Cross-Attention + if self.attn2_2 is not None: + if self.cross_attention_2_norm_type == "ada_norm": + norm_hidden_states = self.norm2_2(hidden_states, timestep) + elif self.cross_attention_2_norm_type in [ + "ada_norm_zero", + "layer_norm", + "layer_norm_i2vgen", + ]: + norm_hidden_states = self.norm2_2(hidden_states) + elif self.cross_attention_2_norm_type == "ada_norm_single": + # For PixArt norm2_2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.cross_attention_2_norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2_2( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + else: + raise ValueError("Incorrect norm") + + if ( + self.pos_embed is not None + and self.cross_attention_2_norm_type != "ada_norm_single" + ): + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output_2 = self.attn2_2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states_2, + attention_mask=encoder_attention_mask_2, + **cross_attention_kwargs, + ) + hidden_states = attn_output_2 + hidden_states + + # 3.2 Parallel Third Cross-Attention + if self.attn2_3 is not None: + if self.cross_attention_3_norm_type == "ada_norm": + norm_hidden_states = self.norm2_3(hidden_states, timestep) + elif self.cross_attention_3_norm_type in [ + "ada_norm_zero", + "layer_norm", + "layer_norm_i2vgen", + ]: + norm_hidden_states = self.norm2_3(hidden_states) + elif self.cross_attention_3_norm_type == "ada_norm_single": + # For PixArt norm2_3 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.cross_attention_3_norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2_3( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + else: + raise ValueError("Incorrect norm") + + if ( + self.pos_embed is not None + and self.cross_attention_3_norm_type != "ada_norm_single" + ): + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output_3 = self.attn2_3( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states_3, + attention_mask=encoder_attention_mask_3, + **cross_attention_kwargs, + ) + hidden_states = attn_output_3 + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.self_attention_norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + elif not self.self_attention_norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.self_attention_norm_type == "ada_norm_zero": + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + if self.self_attention_norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.self_attention_norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.self_attention_norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + +@maybe_allow_in_graph +class FluxSingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + if is_torch_npu_available(): + deprecation_message = ( + "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " + "should be set explicitly using the `set_attn_processor` method." + ) + deprecate("npu_processor", "0.34.0", deprecation_message) + processor = FluxAttnProcessor2_0_NPU() + else: + processor = FluxAttnProcessor2_0() + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +@maybe_allow_in_graph +class FluxTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=FluxAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + ) + + mlp_ratio = 4.0 + self.mlp_hidden_dim = int(dim * mlp_ratio) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward( + dim=dim, dim_out=dim, activation_fn="gelu-approximate" + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, emb=temb + ) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( + self.norm1_context(encoder_hidden_states, emb=temb) + ) + joint_attention_kwargs = joint_attention_kwargs or {} + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + + c_shift_mlp[:, None] + ) + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = ( + encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + ) + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states diff --git a/step1x3d_geometry/models/attention_processor.py b/step1x3d_geometry/models/attention_processor.py new file mode 100755 index 0000000000000000000000000000000000000000..b246f4c990374e80f7e8f401dd9f396ac789a731 --- /dev/null +++ b/step1x3d_geometry/models/attention_processor.py @@ -0,0 +1,482 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Tuple, Union + +import os +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention +from diffusers.utils import logging +from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available +from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph +from einops import rearrange +from torch import nn + +# add sageattention support +scaled_dot_product_attention = F.scaled_dot_product_attention +if os.environ.get("USE_SAGEATTN", "0") == "1": + try: + from sageattention import sageattn + except ImportError: + raise ImportError( + 'Please install the package "sageattention" to use this USE_SAGEATTN.' + ) + scaled_dot_product_attention = sageattn + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FusedAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses + fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. + For cross-attention modules, key and value projection matrices are fused. + + + + This API is currently 🧪 experimental in nature and can change in future. + + + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + query = attn.to_q(hidden_states) + + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q( + encoder_hidden_states_query_proj + ) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k( + encoder_hidden_states_key_proj + ) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FusedFluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q( + encoder_hidden_states_query_proj + ) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k( + encoder_hidden_states_key_proj + ) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states diff --git a/step1x3d_geometry/models/autoencoders/__init__.py b/step1x3d_geometry/models/autoencoders/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..623726edddecfa7d1aec09d765f8d58d0fc581ed --- /dev/null +++ b/step1x3d_geometry/models/autoencoders/__init__.py @@ -0,0 +1,3 @@ +from . import ( + michelangelo_autoencoder, +) diff --git a/step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py b/step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py new file mode 100755 index 0000000000000000000000000000000000000000..130cc84e0d0d7d8462405aa727306ec3f011a327 --- /dev/null +++ b/step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py @@ -0,0 +1,765 @@ +from dataclasses import dataclass +import math + +import torch +import numpy as np +import random +import time +import trimesh +import torch.nn as nn +from einops import repeat, rearrange +from tqdm import trange +from itertools import product +from diffusers.models.modeling_utils import ModelMixin + +import step1x3d_geometry +from step1x3d_geometry.utils.checkpoint import checkpoint +from step1x3d_geometry.utils.base import BaseModule +from step1x3d_geometry.utils.typing import * +from step1x3d_geometry.utils.misc import get_world_size, get_device + +from .transformers.perceiver_1d import Perceiver +from .transformers.attention import ResidualCrossAttentionBlock +from .volume_decoders import HierarchicalVolumeDecoder, VanillaVolumeDecoder +from .surface_extractors import MCSurfaceExtractor, DMCSurfaceExtractor + +from ..pipelines.pipeline_utils import smart_load_model +from safetensors.torch import load_file + +VALID_EMBED_TYPES = ["identity", "fourier", "learned_fourier", "siren"] + + +class FourierEmbedder(nn.Module): + def __init__( + self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True, + ) -> None: + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32) + else: + frequencies = torch.linspace( + 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view( + *x.shape[:-1], -1 + ) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class LearnedFourierEmbedder(nn.Module): + def __init__(self, input_dim, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + per_channel_dim = half_dim // input_dim + self.weights = nn.Parameter(torch.randn(per_channel_dim)) + + self.out_dim = self.get_dims(input_dim) + + def forward(self, x): + # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] + freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) + fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) + return fouriered + + def get_dims(self, input_dim): + return input_dim * (self.weights.shape[0] * 2 + 1) + + +class Sine(nn.Module): + def __init__(self, w0=1.0): + super().__init__() + self.w0 = w0 + + def forward(self, x): + return torch.sin(self.w0 * x) + + +class Siren(nn.Module): + def __init__( + self, + in_dim, + out_dim, + w0=1.0, + c=6.0, + is_first=False, + use_bias=True, + activation=None, + dropout=0.0, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.is_first = is_first + + weight = torch.zeros(out_dim, in_dim) + bias = torch.zeros(out_dim) if use_bias else None + self.init_(weight, bias, c=c, w0=w0) + + self.weight = nn.Parameter(weight) + self.bias = nn.Parameter(bias) if use_bias else None + self.activation = Sine(w0) if activation is None else activation + self.dropout = nn.Dropout(dropout) + + def init_(self, weight, bias, c, w0): + dim = self.in_dim + + w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0) + weight.uniform_(-w_std, w_std) + + if bias is not None: + bias.uniform_(-w_std, w_std) + + def forward(self, x): + out = F.linear(x, self.weight, self.bias) + out = self.activation(out) + out = self.dropout(out) + return out + + +def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True): + if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): + return nn.Identity(), input_dim + + elif embed_type == "fourier": + embedder_obj = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + elif embed_type == "learned_fourier": + embedder_obj = LearnedFourierEmbedder(in_channels=input_dim, dim=num_freqs) + + elif embed_type == "siren": + embedder_obj = Siren( + in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim + ) + + else: + raise ValueError( + f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}" + ) + return embedder_obj + + +###################### AutoEncoder +class DiagonalGaussianDistribution(ModelMixin, object): + def __init__( + self, + parameters: Union[torch.Tensor, List[torch.Tensor]], + deterministic=False, + feat_dim=1, + ): + self.feat_dim = feat_dim + self.parameters = parameters + + if isinstance(parameters, list): + self.mean = parameters[0] + self.logvar = parameters[1] + else: + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + def sample(self): + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None, dims=(1, 2)): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=dims, + ) + + def nll(self, sample, dims=(1, 2)): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +class PerceiverCrossAttentionEncoder(ModelMixin, nn.Module): + def __init__( + self, + use_downsample: bool, + num_latents: int, + embedder: FourierEmbedder, + point_feats: int, + embed_point_feats: bool, + width: int, + heads: int, + layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + qk_norm: bool = True, + use_ln_post: bool = False, + use_flash: bool = False, + use_checkpoint: bool = False, + use_multi_reso: bool = False, + resolutions: list = [], + sampling_prob: list = [], + with_sharp_data: bool = False, + ): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.num_latents = num_latents + self.use_downsample = use_downsample + self.embed_point_feats = embed_point_feats + self.use_multi_reso = use_multi_reso + self.resolutions = resolutions + self.sampling_prob = sampling_prob + + if not self.use_downsample: + self.query = nn.Parameter(torch.randn((num_latents, width)) * 0.02) + + self.embedder = embedder + if self.embed_point_feats: + self.input_proj = nn.Linear(self.embedder.out_dim * 2, width) + else: + self.input_proj = nn.Linear(self.embedder.out_dim + point_feats, width) + + self.cross_attn = ResidualCrossAttentionBlock( + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + use_flash=use_flash, + ) + + self.with_sharp_data = with_sharp_data + if with_sharp_data: + self.downsmaple_num_latents = num_latents // 2 + self.input_proj_sharp = nn.Linear( + self.embedder.out_dim + point_feats, width + ) + self.cross_attn_sharp = ResidualCrossAttentionBlock( + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + use_flash=use_flash, + ) + else: + self.downsmaple_num_latents = num_latents + + self.self_attn = Perceiver( + n_ctx=num_latents, + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + use_flash=use_flash, + use_checkpoint=use_checkpoint, + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width) + else: + self.ln_post = None + + def _forward(self, pc, feats, sharp_pc=None, sharp_feat=None): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + + """ + + bs, N, D = pc.shape + + data = self.embedder(pc) + if feats is not None: + if self.embed_point_feats: + feats = self.embedder(feats) + data = torch.cat([data, feats], dim=-1) + data = self.input_proj(data) + + if self.with_sharp_data: + sharp_data = self.embedder(sharp_pc) + if sharp_feat is not None: + if self.embed_point_feats: + sharp_feat = self.embedder(sharp_feat) + sharp_data = torch.cat([sharp_data, sharp_feat], dim=-1) + sharp_data = self.input_proj_sharp(sharp_data) + + if self.use_multi_reso: + resolution = random.choice(self.resolutions, size=1, p=self.sampling_prob)[ + 0 + ] + + if resolution != N: + flattened = pc.view(bs * N, D) # bs*N, 64. 103,4096,3 -> 421888,3 + batch = torch.arange(bs).to(pc.device) # 103 + batch = torch.repeat_interleave(batch, N) # bs*N. 421888 + pos = flattened.to(torch.float16) + ratio = 1.0 * resolution / N # 0.0625 + idx = fps(pos, batch, ratio=ratio) # 26368 + pc = pc.view(bs * N, -1)[idx].view(bs, -1, D) + bs, N, D = feats.shape + flattened1 = feats.view(bs * N, D) + feats = flattened1.view(bs * N, -1)[idx].view(bs, -1, D) + bs, N, D = pc.shape + + if self.use_downsample: + ###### fps + from torch_cluster import fps + + flattened = pc.view(bs * N, D) # bs*N, 64 + + batch = torch.arange(bs).to(pc.device) + batch = torch.repeat_interleave(batch, N) # bs*N + + pos = flattened.to(torch.float16) + ratio = 1.0 * self.downsmaple_num_latents / N + idx = fps(pos, batch, ratio=ratio).detach() + query = data.view(bs * N, -1)[idx].view(bs, -1, data.shape[-1]) + + if self.with_sharp_data: + bs, N, D = sharp_pc.shape + flattened = sharp_pc.view(bs * N, D) # bs*N, 64 + pos = flattened.to(torch.float16) + ratio = 1.0 * self.downsmaple_num_latents / N + idx = fps(pos, batch, ratio=ratio).detach() + sharp_query = sharp_data.view(bs * N, -1)[idx].view( + bs, -1, sharp_data.shape[-1] + ) + query = torch.cat([query, sharp_query], dim=1) + else: + query = self.query + query = repeat(query, "m c -> b m c", b=bs) + + latents = self.cross_attn(query, data) + if self.with_sharp_data: + latents = latents + self.cross_attn_sharp(query, sharp_data) + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + return latents + + def forward( + self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sharp_pc: Optional[torch.FloatTensor] = None, + sharp_feats: Optional[torch.FloatTensor] = None, + ): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + dict + """ + + return checkpoint( + self._forward, + (pc, feats, sharp_pc, sharp_feats), + self.parameters(), + self.use_checkpoint, + ) + + +class PerceiverCrossAttentionDecoder(ModelMixin, nn.Module): + + def __init__( + self, + num_latents: int, + out_dim: int, + embedder: FourierEmbedder, + width: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + qk_norm: bool = True, + use_flash: bool = False, + use_checkpoint: bool = False, + ): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.embedder = embedder + + self.query_proj = nn.Linear(self.embedder.out_dim, width) + + self.cross_attn_decoder = ResidualCrossAttentionBlock( + n_data=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + use_flash=use_flash, + ) + + self.ln_post = nn.LayerNorm(width) + self.output_proj = nn.Linear(width, out_dim) + + def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + queries = self.query_proj(self.embedder(queries)) + x = self.cross_attn_decoder(queries, latents) + x = self.ln_post(x) + x = self.output_proj(x) + return x + + def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + return checkpoint( + self._forward, (queries, latents), self.parameters(), self.use_checkpoint + ) + + +@step1x3d_geometry.register("michelangelo-autoencoder") +class MichelangeloAutoencoder(BaseModule): + r""" + A VAE model for encoding shapes into latents and decoding latent representations into shapes. + """ + + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: str = "" + subfolder: str = "" + n_samples: int = 4096 + use_downsample: bool = False + downsample_ratio: float = 0.0625 + num_latents: int = 256 + point_feats: int = 0 + embed_point_feats: bool = False + out_dim: int = 1 + embed_dim: int = 64 + embed_type: str = "fourier" + num_freqs: int = 8 + include_pi: bool = True + width: int = 768 + heads: int = 12 + num_encoder_layers: int = 8 + num_decoder_layers: int = 16 + init_scale: float = 0.25 + qkv_bias: bool = True + qk_norm: bool = False + use_ln_post: bool = False + use_flash: bool = False + use_checkpoint: bool = True + use_multi_reso: Optional[bool] = False + resolutions: Optional[List[int]] = None + sampling_prob: Optional[List[float]] = None + with_sharp_data: Optional[bool] = True + volume_decoder_type: str = "hierarchical" + surface_extractor_type: str = "mc" + z_scale_factor: float = 1.0 + + cfg: Config + + def configure(self) -> None: + super().configure() + + self.embedder = get_embedder( + embed_type=self.cfg.embed_type, + num_freqs=self.cfg.num_freqs, + include_pi=self.cfg.include_pi, + ) + + # encoder + self.cfg.init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width) + self.encoder = PerceiverCrossAttentionEncoder( + use_downsample=self.cfg.use_downsample, + embedder=self.embedder, + num_latents=self.cfg.num_latents, + point_feats=self.cfg.point_feats, + embed_point_feats=self.cfg.embed_point_feats, + width=self.cfg.width, + heads=self.cfg.heads, + layers=self.cfg.num_encoder_layers, + init_scale=self.cfg.init_scale, + qkv_bias=self.cfg.qkv_bias, + qk_norm=self.cfg.qk_norm, + use_ln_post=self.cfg.use_ln_post, + use_flash=self.cfg.use_flash, + use_checkpoint=self.cfg.use_checkpoint, + use_multi_reso=self.cfg.use_multi_reso, + resolutions=self.cfg.resolutions, + sampling_prob=self.cfg.sampling_prob, + with_sharp_data=self.cfg.with_sharp_data, + ) + + if self.cfg.embed_dim > 0: + # VAE embed + self.pre_kl = nn.Linear(self.cfg.width, self.cfg.embed_dim * 2) + self.post_kl = nn.Linear(self.cfg.embed_dim, self.cfg.width) + self.latent_shape = (self.cfg.num_latents, self.cfg.embed_dim) + else: + self.latent_shape = (self.cfg.num_latents, self.cfg.width) + + self.transformer = Perceiver( + n_ctx=self.cfg.num_latents, + width=self.cfg.width, + layers=self.cfg.num_decoder_layers, + heads=self.cfg.heads, + init_scale=self.cfg.init_scale, + qkv_bias=self.cfg.qkv_bias, + qk_norm=self.cfg.qk_norm, + use_flash=self.cfg.use_flash, + use_checkpoint=self.cfg.use_checkpoint, + ) + + # decoder + self.decoder = PerceiverCrossAttentionDecoder( + embedder=self.embedder, + out_dim=self.cfg.out_dim, + num_latents=self.cfg.num_latents, + width=self.cfg.width, + heads=self.cfg.heads, + init_scale=self.cfg.init_scale, + qkv_bias=self.cfg.qkv_bias, + qk_norm=self.cfg.qk_norm, + use_flash=self.cfg.use_flash, + use_checkpoint=self.cfg.use_checkpoint, + ) + + # volume decoder + if self.cfg.volume_decoder_type == "hierarchical": + self.volume_decoder = HierarchicalVolumeDecoder() + else: + self.volume_decoder = VanillaVolumeDecoder() + + if self.cfg.pretrained_model_name_or_path != "": + local_model_path = f"{smart_load_model(self.cfg.pretrained_model_name_or_path, self.cfg.subfolder)}/vae/diffusion_pytorch_model.safetensors" + pretrain_safetensors = load_file(local_model_path) + print(f"Loading pretrained VAE model from {local_model_path}") + + if "state_dict" in pretrain_safetensors: + _pretrained_safetensors = {} + for k, v in pretrain_safetensors["state_dict"].items(): + if k.startswith("shape_model."): + if "proj1" in k: + _pretrained_safetensors[ + k.replace("shape_model.", "").replace( + "proj1", "proj_sharp" + ) + ] = v + elif "attn1" in k: + _pretrained_safetensors[ + k.replace("shape_model.", "").replace( + "attn1", "attn_sharp" + ) + ] = v + else: + _pretrained_safetensors[k.replace("shape_model.", "")] = v + + pretrain_safetensors = _pretrained_safetensors + self.load_state_dict(pretrain_safetensors, strict=True) + else: + _pretrained_safetensors = {} + for k, v in pretrain_safetensors.items(): + if k.startswith("shape_model"): + final_module = self + for key in k.replace("shape_model.", "").split("."): + final_module = getattr(final_module, key) + data = final_module.data + data_zero = torch.zeros_like(data).to(v) + + if data.shape != v.shape: + if data.ndim == 1: + data_zero[: v.shape[0]] = v + elif data.ndim == 2: + data_zero[: v.shape[0], : v.shape[1]] = v + v = data_zero + + _pretrained_safetensors[k.replace("shape_model.", "")] = v + else: + _pretrained_safetensors[k] = v + pretrain_safetensors = _pretrained_safetensors + self.load_state_dict(pretrain_safetensors, strict=True) + print("Successed load pretrained VAE model") + + def encode( + self, + surface: torch.FloatTensor, + sample_posterior: bool = True, + sharp_surface: torch.FloatTensor = None, + ): + """ + Args: + surface (torch.FloatTensor): [B, N, 3+C] + sample_posterior (bool): + + Returns: + shape_latents (torch.FloatTensor): [B, num_latents, width] + kl_embed (torch.FloatTensor): [B, num_latents, embed_dim] + posterior (DiagonalGaussianDistribution or None): + """ + assert ( + surface.shape[-1] == 3 + self.cfg.point_feats + ), f"\ + Expected {3 + self.cfg.point_feats} channels, got {surface.shape[-1]}" + + pc, feats = surface[..., :3], surface[..., 3:] # B, n_samples, 3 + if sharp_surface is not None: + sharp_pc, sharp_feats = ( + sharp_surface[..., :3], + sharp_surface[..., 3:], + ) # B, n_samples, 3 + else: + sharp_pc, sharp_feats = None, None + + shape_embeds = self.encoder( + pc, feats, sharp_pc, sharp_feats + ) # B, num_latents, width + kl_embed, posterior = self.encode_kl_embed( + shape_embeds, sample_posterior + ) # B, num_latents, embed_dim + + kl_embed = kl_embed * self.cfg.z_scale_factor # encode with scale + + return shape_embeds, kl_embed, posterior + + def decode(self, latents: torch.FloatTensor): + """ + Args: + latents (torch.FloatTensor): [B, embed_dim] + + Returns: + latents (torch.FloatTensor): [B, embed_dim] + """ + latents = self.post_kl( + latents / self.cfg.z_scale_factor + ) # [B, num_latents, embed_dim] -> [B, num_latents, width] + + return self.transformer(latents) + + def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + """ + Args: + queries (torch.FloatTensor): [B, N, 3] + latents (torch.FloatTensor): [B, embed_dim] + + Returns: + features (torch.FloatTensor): [B, N, C], output features + """ + + features = self.decoder(queries, latents) + + return features + + def encode_kl_embed( + self, latents: torch.FloatTensor, sample_posterior: bool = True + ): + posterior = None + if self.cfg.embed_dim > 0: + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + if sample_posterior: + kl_embed = posterior.sample() + else: + kl_embed = posterior.mode() + else: + kl_embed = latents + return kl_embed, posterior + + def forward( + self, + surface: torch.FloatTensor, + sharp_surface: torch.FloatTensor = None, + rand_points: torch.FloatTensor = None, + sample_posterior: bool = True, + **kwargs, + ): + shape_latents, kl_embed, posterior = self.encode( + surface, sample_posterior=sample_posterior, sharp_surface=sharp_surface + ) + + latents = self.decode(kl_embed) # [B, num_latents, width] + + meshes = self.extract_geometry(latents, **kwargs) + + return shape_latents, latents, posterior, meshes + + def extract_geometry(self, latents: torch.FloatTensor, **kwargs): + + grid_logits_list = [] + for i in range(latents.shape[0]): + grid_logits = self.volume_decoder( + latents[i].unsqueeze(0), self.query, **kwargs + ) + grid_logits_list.append(grid_logits) + grid_logits = torch.cat(grid_logits_list, dim=0) + + # extract mesh + surface_extractor_type = ( + kwargs["surface_extractor_type"] + if "surface_extractor_type" in kwargs.keys() + and kwargs["surface_extractor_type"] is not None + else self.cfg.surface_extractor_type + ) + + if surface_extractor_type == "mc": + surface_extractor = MCSurfaceExtractor() + meshes = surface_extractor(grid_logits, **kwargs) + elif surface_extractor_type == "dmc": + surface_extractor = DMCSurfaceExtractor() + meshes = surface_extractor(grid_logits, **kwargs) + else: + raise NotImplementedError + + return meshes diff --git a/step1x3d_geometry/models/autoencoders/surface_extractors.py b/step1x3d_geometry/models/autoencoders/surface_extractors.py new file mode 100755 index 0000000000000000000000000000000000000000..067307df42c9c92cd477628cbd641d6ddaa194bd --- /dev/null +++ b/step1x3d_geometry/models/autoencoders/surface_extractors.py @@ -0,0 +1,137 @@ +from typing import Union, Tuple, List + +import numpy as np +import torch +from skimage import measure + + +class MeshExtractResult: + def __init__(self, verts, faces, vertex_attrs=None, res=64): + self.verts = verts + self.faces = faces.long() + self.vertex_attrs = vertex_attrs + self.face_normal = self.comput_face_normals() + self.vert_normal = self.comput_v_normals() + self.res = res + self.success = verts.shape[0] != 0 and faces.shape[0] != 0 + + # training only + self.tsdf_v = None + self.tsdf_s = None + self.reg_loss = None + + def comput_face_normals(self): + i0 = self.faces[..., 0].long() + i1 = self.faces[..., 1].long() + i2 = self.faces[..., 2].long() + + v0 = self.verts[i0, :] + v1 = self.verts[i1, :] + v2 = self.verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + face_normals = torch.nn.functional.normalize(face_normals, dim=1) + return face_normals[:, None, :].repeat(1, 3, 1) + + def comput_v_normals(self): + i0 = self.faces[..., 0].long() + i1 = self.faces[..., 1].long() + i2 = self.faces[..., 2].long() + + v0 = self.verts[i0, :] + v1 = self.verts[i1, :] + v2 = self.verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + v_normals = torch.zeros_like(self.verts) + v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) + + v_normals = torch.nn.functional.normalize(v_normals, dim=1) + return v_normals + + +def center_vertices(vertices): + """Translate the vertices so that bounding box is centered at zero.""" + vert_min = vertices.min(dim=0)[0] + vert_max = vertices.max(dim=0)[0] + vert_center = 0.5 * (vert_min + vert_max) + return vertices - vert_center + + +class SurfaceExtractor: + def _compute_box_stat( + self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int + ): + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + grid_size = [ + int(octree_resolution) + 1, + int(octree_resolution) + 1, + int(octree_resolution) + 1, + ] + return grid_size, bbox_min, bbox_size + + def run(self, *args, **kwargs): + return NotImplementedError + + def __call__(self, grid_logits, **kwargs): + outputs = [] + for i in range(grid_logits.shape[0]): + try: + verts, faces = self.run(grid_logits[i], **kwargs) + outputs.append( + MeshExtractResult( + verts=verts.float(), + faces=faces, + res=kwargs["octree_resolution"], + ) + ) + + except Exception: + import traceback + + traceback.print_exc() + outputs.append(None) + + return outputs + + +class MCSurfaceExtractor(SurfaceExtractor): + def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs): + verts, faces, normals, _ = measure.marching_cubes( + grid_logit.float().cpu().numpy(), mc_level, method="lewiner" + ) + grid_size, bbox_min, bbox_size = self._compute_box_stat( + bounds, octree_resolution + ) + verts = verts / grid_size * bbox_size + bbox_min + verts = torch.tensor(verts, device=grid_logit.device, dtype=torch.float32) + faces = torch.tensor( + np.ascontiguousarray(faces), device=grid_logit.device, dtype=torch.long + ) + faces = faces[:, [2, 1, 0]] + return verts, faces + + +class DMCSurfaceExtractor(SurfaceExtractor): + def run(self, grid_logit, *, octree_resolution, **kwargs): + device = grid_logit.device + if not hasattr(self, "dmc"): + try: + from diso import DiffDMC + except: + raise ImportError( + "Please install diso via `pip install diso`, or set mc_algo to 'mc'" + ) + self.dmc = DiffDMC(dtype=torch.float32).to(device) + sdf = -grid_logit / octree_resolution + sdf = sdf.to(torch.float32).contiguous() + verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True) + grid_size, bbox_min, bbox_size = self._compute_box_stat( + kwargs["bounds"], octree_resolution + ) + verts = verts * kwargs["bounds"] * 2 - kwargs["bounds"] + return verts, faces diff --git a/step1x3d_geometry/models/autoencoders/transformers/attention.py b/step1x3d_geometry/models/autoencoders/transformers/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..61e2b583a37a89caffabb6a4c4f65514db19f6ab --- /dev/null +++ b/step1x3d_geometry/models/autoencoders/transformers/attention.py @@ -0,0 +1,286 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from step1x3d_geometry.utils.typing import * +from step1x3d_geometry.utils.checkpoint import checkpoint + +from .utils import init_linear, MLP +from timm.models.vision_transformer import Attention + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + n_ctx: int, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool, + qk_norm: bool, + norm_layer=nn.LayerNorm, + use_flash: bool = False, + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadAttention( + heads=heads, + n_ctx=n_ctx, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm, + use_flash=use_flash, + ) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__( + self, + *, + heads: int, + n_ctx: int, + width=None, + qk_norm: bool = False, + norm_layer=nn.LayerNorm, + use_flash: bool = False, + ): + super().__init__() + self.heads = heads + self.n_ctx = n_ctx + self.use_flash = use_flash + + self.q_norm = ( + norm_layer(width // heads, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.k_norm = ( + norm_layer(width // heads, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + + if self.use_flash: + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + out = ( + F.scaled_dot_product_attention(q, k, v) + .permute(0, 2, 1, 3) + .reshape(bs, n_ctx, -1) + ) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + n_ctx: int, + width: int, + heads: int, + init_scale: float = 1.0, + qkv_bias: bool = True, + norm_layer=nn.LayerNorm, + qk_norm: bool = True, + use_flash: bool = False, + use_checkpoint: bool = False, + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + use_flash=use_flash, + ) + self.ln_1 = nn.LayerNorm(width) + self.mlp = MLP(width=width, init_scale=init_scale) + self.ln_2 = nn.LayerNorm(width) + + def _forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def forward(self, x: torch.Tensor): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool = True, + norm_layer=nn.LayerNorm, + qk_norm: bool = True, + use_flash: bool = False, + n_data: Optional[int] = None, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadCrossAttention( + heads=heads, + n_data=n_data, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm, + use_flash=use_flash, + ) + init_linear(self.c_q, init_scale) + init_linear(self.c_kv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + heads: int, + n_data: Optional[int] = None, + width=None, + norm_layer=nn.LayerNorm, + qk_norm: bool = False, + use_flash: bool = False, + ): + + super().__init__() + self.heads = heads + self.n_data = n_data + self.use_flash = use_flash + + self.q_norm = ( + norm_layer(width // heads, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.k_norm = ( + norm_layer(width // heads, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + + if self.use_flash: + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + out = ( + F.scaled_dot_product_attention(q, k, v) + .permute(0, 2, 1, 3) + .reshape(bs, n_ctx, -1) + ) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + n_data: Optional[int] = None, + width: int, + heads: int, + data_width: Optional[int] = None, + init_scale: float = 0.25, + qkv_bias: bool = True, + qk_norm: bool = True, + use_flash: bool = False, + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + use_flash=use_flash, + ) + self.ln_1 = nn.LayerNorm(width) + self.ln_2 = nn.LayerNorm(data_width) + self.mlp = MLP(width=width, init_scale=init_scale) + self.ln_3 = nn.LayerNorm(width) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x diff --git a/step1x3d_geometry/models/autoencoders/transformers/perceiver_1d.py b/step1x3d_geometry/models/autoencoders/transformers/perceiver_1d.py new file mode 100755 index 0000000000000000000000000000000000000000..a77204d271d1f1fca7503567703b16ec5922a654 --- /dev/null +++ b/step1x3d_geometry/models/autoencoders/transformers/perceiver_1d.py @@ -0,0 +1,50 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from step1x3d_geometry.utils.typing import * +from step1x3d_geometry.utils.checkpoint import checkpoint + +from .utils import init_linear +from .attention import ResidualAttentionBlock + + +class Perceiver(nn.Module): + def __init__( + self, + *, + n_ctx: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + qk_norm: bool = True, + use_flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + use_flash=use_flash, + use_checkpoint=use_checkpoint, + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/step1x3d_geometry/models/autoencoders/transformers/utils.py b/step1x3d_geometry/models/autoencoders/transformers/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..5c0c94cc2ccd731d0bc544aee6f417058186421c --- /dev/null +++ b/step1x3d_geometry/models/autoencoders/transformers/utils.py @@ -0,0 +1,21 @@ +import torch.nn as nn + + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + + +class MLP(nn.Module): + def __init__(self, *, width: int, init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4) + self.c_proj = nn.Linear(width * 4, width) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) diff --git a/step1x3d_geometry/models/autoencoders/volume_decoders.py b/step1x3d_geometry/models/autoencoders/volume_decoders.py new file mode 100755 index 0000000000000000000000000000000000000000..6e3d671a69e15a5c82ec9941d2052f55795f72af --- /dev/null +++ b/step1x3d_geometry/models/autoencoders/volume_decoders.py @@ -0,0 +1,327 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +from typing import Union, Tuple, List, Callable + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat +from tqdm import tqdm + +cube_corners = torch.tensor( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ], + dtype=torch.int, +) + + +def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float): + device = input_tensor.device + D = input_tensor.shape[0] + signed_val = 0.0 + + # 添加偏移并处理无效值 + val = input_tensor + alpha + valid_mask = val > -9000 # 假设-9000是无效值 + + # 改进的邻居获取函数(保持维度一致) + def get_neighbor(t, shift, axis): + """根据指定轴进行位移并保持维度一致""" + if shift == 0: + return t.clone() + + # 确定填充轴(输入为[D, D, D]对应z,y,x轴) + pad_dims = [0, 0, 0, 0, 0, 0] # 格式:[x前,x后,y前,y后,z前,z后] + + # 根据轴类型设置填充 + if axis == 0: # x轴(最后一个维度) + pad_idx = 0 if shift > 0 else 1 + pad_dims[pad_idx] = abs(shift) + elif axis == 1: # y轴(中间维度) + pad_idx = 2 if shift > 0 else 3 + pad_dims[pad_idx] = abs(shift) + elif axis == 2: # z轴(第一个维度) + pad_idx = 4 if shift > 0 else 5 + pad_dims[pad_idx] = abs(shift) + + # 执行填充(添加batch和channel维度适配F.pad) + padded = F.pad( + t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode="replicate" + ) # 反转顺序适配F.pad + + # 构建动态切片索引 + slice_dims = [slice(None)] * 3 # 初始化为全切片 + if axis == 0: # x轴(dim=2) + if shift > 0: + slice_dims[0] = slice(shift, None) + else: + slice_dims[0] = slice(None, shift) + elif axis == 1: # y轴(dim=1) + if shift > 0: + slice_dims[1] = slice(shift, None) + else: + slice_dims[1] = slice(None, shift) + elif axis == 2: # z轴(dim=0) + if shift > 0: + slice_dims[2] = slice(shift, None) + else: + slice_dims[2] = slice(None, shift) + + # 应用切片并恢复维度 + padded = padded.squeeze(0).squeeze(0) + sliced = padded[slice_dims] + return sliced + + # 获取各方向邻居(确保维度一致) + left = get_neighbor(val, 1, axis=0) # x方向 + right = get_neighbor(val, -1, axis=0) + back = get_neighbor(val, 1, axis=1) # y方向 + front = get_neighbor(val, -1, axis=1) + down = get_neighbor(val, 1, axis=2) # z方向 + up = get_neighbor(val, -1, axis=2) + + # 处理边界无效值(使用where保持维度一致) + def safe_where(neighbor): + return torch.where(neighbor > -9000, neighbor, val) + + left = safe_where(left) + right = safe_where(right) + back = safe_where(back) + front = safe_where(front) + down = safe_where(down) + up = safe_where(up) + + # 计算符号一致性(转换为float32确保精度) + sign = torch.sign(val.to(torch.float32)) + neighbors_sign = torch.stack( + [ + torch.sign(left.to(torch.float32)), + torch.sign(right.to(torch.float32)), + torch.sign(back.to(torch.float32)), + torch.sign(front.to(torch.float32)), + torch.sign(down.to(torch.float32)), + torch.sign(up.to(torch.float32)), + ], + dim=0, + ) + + # 检查所有符号是否一致 + same_sign = torch.all(neighbors_sign == sign, dim=0) + + # 生成最终掩码 + mask = (~same_sign).to(torch.int32) + return mask * valid_mask.to(torch.int32) + + +def generate_dense_grid_points( + bbox_min: np.ndarray, + bbox_max: np.ndarray, + octree_resolution: int, + indexing: str = "ij", +): + length = bbox_max - bbox_min + num_cells = octree_resolution + + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + + +class VanillaVolumeDecoder: + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: Callable, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + octree_resolution: int = 384, + enable_pbar: bool = True, + **kwargs, + ): + device = latents.device + dtype = latents.dtype + batch_size = latents.shape[0] + + # 1. generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=octree_resolution, + indexing="ij", + ) + xyz_samples = ( + torch.from_numpy(xyz_samples) + .to(device, dtype=dtype) + .contiguous() + .reshape(-1, 3) + ) + + # 2. latents to 3d volume + batch_features = [] + for start in tqdm( + range(0, xyz_samples.shape[0], num_chunks), + desc=f"Volume Decoding", + disable=not enable_pbar, + ): + chunk_queries = xyz_samples[start : start + num_chunks, :] + chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) + features = geo_decoder(queries=chunk_queries, latents=latents) + batch_features.append(features) + + grid_features = torch.cat(batch_features, dim=1) + grid_logits, grid_features = grid_features[..., 0:1], grid_features[..., 1:] + grid_logits = grid_logits.view((batch_size, *grid_size)).float() + + return grid_logits, xyz_samples, grid_features, None + + +class HierarchicalVolumeDecoder: + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: Callable, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 65536, + mc_level: float = 0.0, + octree_resolution: int = 384, + min_resolution: int = 63, + enable_pbar: bool = True, + empty_value: float = float("nan"), + **kwargs, + ): + device = latents.device + dtype = latents.dtype + + resolutions = [] + if octree_resolution < min_resolution: + resolutions.append(octree_resolution) + while octree_resolution >= min_resolution: + resolutions.append(octree_resolution) + octree_resolution = octree_resolution // 2 + resolutions.reverse() + + # 1. generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=resolutions[0], + indexing="ij", + ) + + dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype) + dilate.weight = torch.nn.Parameter( + torch.ones(dilate.weight.shape, dtype=dtype, device=device) + ) + + grid_size = np.array(grid_size) + xyz_samples = ( + torch.from_numpy(xyz_samples) + .to(device, dtype=dtype) + .contiguous() + .reshape(-1, 3) + ) + + # 2. latents to 3d volume + batch_features = [] + batch_size = latents.shape[0] + for start in tqdm( + range(0, xyz_samples.shape[0], num_chunks), + desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]", + disable=not enable_pbar, + ): + queries = xyz_samples[start : start + num_chunks, :] + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + features = geo_decoder(queries=batch_queries, latents=latents) + batch_features.append(features) + + grid_features = torch.cat(batch_features, dim=1).view( + (batch_size, grid_size[0], grid_size[1], grid_size[2], -1) + ) + grid_logits = grid_features[..., 0] # assume the first element is the logits + + for octree_depth_now in resolutions[1:]: + grid_size = np.array([octree_depth_now + 1] * 3) + resolution = bbox_size / octree_depth_now + next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device) + next_logits = torch.full( + next_index.shape, -10000.0, dtype=dtype, device=device + ) + curr_points = extract_near_surface_volume_fn( + grid_logits.squeeze(0), mc_level + ) + curr_points += grid_logits.squeeze(0).abs() < 0.95 + + if octree_depth_now == resolutions[-1]: + expand_num = 0 + else: + expand_num = 1 + for i in range(expand_num): + curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0) + (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0) + next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1 + for i in range(2 - expand_num): + next_index = dilate(next_index.unsqueeze(0)).squeeze(0) + nidx = torch.where(next_index > 0) + + next_points = torch.stack(nidx, dim=1) + next_points = next_points * torch.tensor( + resolution, dtype=latents.dtype, device=device + ) + torch.tensor(bbox_min, dtype=latents.dtype, device=device) + + batch_features = [] + for start in tqdm( + range(0, next_points.shape[0], num_chunks), + desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]", + disable=not enable_pbar, + ): + queries = next_points[start : start + num_chunks, :] + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + features = geo_decoder( + queries=batch_queries.to(latents.dtype), latents=latents + ) + batch_features.append(features) + grid_features = torch.cat(batch_features, dim=1) + grid_logits = grid_features[..., 0:1] + next_logits[nidx] = grid_logits[0, ..., 0] + grid_logits = next_logits.unsqueeze(0) + grid_logits[grid_logits == -10000.0] = empty_value + + return grid_logits diff --git a/step1x3d_geometry/models/conditional_encoders/__init__.py b/step1x3d_geometry/models/conditional_encoders/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e31a3a5d954ca53b0837e238dc802345f62e0bf0 --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/__init__.py @@ -0,0 +1,6 @@ +from . import ( + dinov2_encoder, + dinov2_clip_encoder, + t5_encoder, + label_encoder, +) diff --git a/step1x3d_geometry/models/conditional_encoders/base.py b/step1x3d_geometry/models/conditional_encoders/base.py new file mode 100755 index 0000000000000000000000000000000000000000..0e61287e51aadc1be9c42aec35cbb78d74835abf --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/base.py @@ -0,0 +1,202 @@ +import random +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +from dataclasses import dataclass +from torchvision.transforms import Normalize +from torchvision.transforms import InterpolationMode +from torchvision.transforms.transforms import _interpolation_modes_from_int + +from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor +from transformers.utils import ModelOutput +from typing import Iterable, Optional, Union, List + +import step1x3d_geometry +from step1x3d_geometry.utils.base import BaseModule +from step1x3d_geometry.utils.typing import * + +ImageType = Union[np.ndarray, torch.Tensor, Image.Image] + + +class BaseVisualEncoder(BaseModule): + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: Optional[str] = ( + None # the pretrained model name or path + ) + + encode_camera: bool = False # whether to encode camera + camera_embeds_type: str = "sincos" # the type of camera embeds + camera_embeds_dim: Optional[int] = None # the dimension of camera embeds + n_views: int = 1 # the number of views + + empty_embeds_ratio: float = 0.1 # the ratio of empty embeds + normalize_embeds: bool = False # whether to normalize the embeds + zero_uncond_embeds: bool = True + + cfg: Config + + def configure(self) -> None: + super().configure() + + if self.cfg.encode_camera: + self.distance = 1.0 + self.register_buffer( + "cameras", + torch.as_tensor( + [ + [ + [1, 0, 0, 0], + [0, 0, -1, -self.distance], + [0, 1, 0, 0], + [0, 0, 0, 1], + ], # front to back + [ + [0, 0, 1, self.distance], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ], # right to left + [ + [-1, 0, 0, 0], + [0, 0, 1, self.distance], + [0, 1, 0, 0], + [0, 0, 0, 1], + ], # back to front + [ + [0, 0, -1, -self.distance], + [-1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ], # left to right + ], + dtype=torch.float32, + ), + ) + + def encode_image( + self, + images: Iterable[Optional[ImageType]], + camera_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.FloatTensor: + raise NotImplementedError + + def encode_camera(self, c2ws: torch.Tensor): + if self.cfg.camera_embeds_type == "sincos": + assert ( + c2ws.shape[-1] == 4 and c2ws.shape[-2] == 4 + ), f"Invalid c2ws shape: {c2ws.shape}" + c2ws = c2ws.view(-1, 16) + return torch.cat([torch.sin(c2ws), torch.cos(c2ws)], dim=-1) + else: + raise NotImplementedError( + f"Unknown camera_embeds_type: {self.cfg.camera_embeds_type}" + ) + + def forward(self, batch): + assert ( + "image" in batch or "mvimages" in batch + ), "image or mvimages is required for visual embeds" + if batch["image"].dim() == 5: + bs = batch["image"].shape[0] * batch["image"].shape[1] + else: + bs = batch["image"].shape[0] + + if random.random() < self.cfg.empty_embeds_ratio: + if "image" in batch or "image_embeds" in batch: + visual_embeds = self.empty_image_embeds.repeat(bs, 1, 1) + elif "mvimages" in batch or "mvimage_embeds" in batch: + visual_embeds = self.empty_image_embeds.unsqueeze(1).repeat(bs, 1, 1, 1) + else: + # for visual inputs + if "image" in batch: + if self.cfg.encode_camera: + visual_embeds = self.encode_image( + batch["image"], cameras=batch["c2w"] + ) + else: + visual_embeds = self.encode_image(batch["image"]) + elif "mvimages" in batch: + n_views = batch["mvimages"].shape[1] + if self.cfg.encode_camera: + visual_embeds = self.encode_image( + batch["mvimages"].view(-1, *batch["mvimages"].shape[-3:]), + cameras=batch["c2ws"], + ).view(bs, n_views, *self.empty_image_embeds.shape[-2:]) + else: + visual_embeds = self.encode_image( + batch["mvimages"].view(-1, *batch["mvimages"].shape[-3:]) + ).view(bs, n_views, *self.empty_image_embeds.shape[-2:]) + + if self.cfg.normalize_embeds: # post-process the visual embeds + visual_embeds = visual_embeds / visual_embeds.norm(dim=-1, keepdim=True) + + return visual_embeds + + +class BaseCaptionEncoder(BaseModule): + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: Optional[str] = ( + None # the pretrained model name or path + ) + + text_max_length: int = 77 + + empty_embeds_ratio: float = 0.1 # the ratio of empty embeds + normalize_embeds: bool = False # whether to normalize the embeds + zero_uncond_embeds: bool = True + + cfg: Config + + def configure(self) -> None: + super().configure() + + def forward(self, batch, force_drop_ids=None): + assert "caption" in batch, "caption is required for caption embeds" + + bs = len(batch["label"]) + if random.random() < self.cfg.empty_embeds_ratio: + caption_embeds = self.empty_text_embeds.repeat(bs, 1, 1) + else: + caption_embeds = self.encode_text(batch["caption"]) + + if self.cfg.normalize_embeds: # post-process the label embeds + caption_embeds = caption_embeds / caption_embeds.norm(dim=-1, keepdim=True) + + return caption_embeds + + +class BaseLabelEncoder(BaseModule): + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: Optional[str] = ( + None # the pretrained model name or path + ) + + hidden_size: int = 1024 + + empty_embeds_ratio: float = 0.1 # the ratio of empty embeds + normalize_embeds: bool = False # whether to normalize the embeds + zero_uncond_embeds: bool = True + + cfg: Config + + def configure(self) -> None: + super().configure() + + def forward(self, batch, force_drop_ids=None): + assert "label" in batch, "label is required for label embeds" + + bs = len(batch["label"]) + if random.random() < self.cfg.empty_embeds_ratio: + label_embeds = self.empty_label_embeds.repeat(bs, 1, 1) + else: + label_embeds = self.encode_label(batch["label"]) + + if self.cfg.normalize_embeds: # post-process the label embeds + label_embeds = label_embeds / label_embeds.norm(dim=-1, keepdim=True) + + return label_embeds diff --git a/step1x3d_geometry/models/conditional_encoders/clip/modeling_clip.py b/step1x3d_geometry/models/conditional_encoders/clip/modeling_clip.py new file mode 100755 index 0000000000000000000000000000000000000000..b67639214cd4da0784176d24019cf5847ec80735 --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/clip/modeling_clip.py @@ -0,0 +1,1597 @@ +# coding=utf-8 +# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch CLIP model.""" + + +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _create_4d_causal_attention_mask, + _prepare_4d_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.clip.configuration_clip import ( + CLIPConfig, + CLIPTextConfig, + CLIPVisionConfig, +) + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "CLIPConfig" +_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32" + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "openai/clip-vit-base-patch32" +_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0" + + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/2021-03-07-clip.html +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy( + logits, torch.arange(len(logits), device=logits.device) + ) + + +def clip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class CLIPVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CLIPTextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CLIPOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`CLIPTextModel`]. + vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`CLIPVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + ( + self[k] + if k not in ["text_model_output", "vision_model_output"] + else getattr(self, k).to_tuple() + ) + for k in self.keys() + ) + + +class CLIPVisionEmbeddings(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class CLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding( + config.max_position_embeddings, embed_dim + ) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = ( + input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + causal_attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class CLIPMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + def __init__(self, config: CLIPConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class CLIPPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = CLIPConfig + base_model_prefix = "clip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, CLIPTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, CLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_( + module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor + ) + nn.init.normal_( + module.patch_embedding.weight, + std=module.config.initializer_range * factor, + ) + nn.init.normal_( + module.position_embedding.weight, + std=module.config.initializer_range * factor, + ) + elif isinstance(module, CLIPAttention): + factor = self.config.initializer_factor + in_proj_std = ( + (module.embed_dim**-0.5) + * ((2 * module.config.num_hidden_layers) ** -0.5) + * factor + ) + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, CLIPMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size**-0.5) + * ((2 * module.config.num_hidden_layers) ** -0.5) + * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, CLIPModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPVisionModelWithProjection): + nn.init.normal_( + module.visual_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPTextModelWithProjection): + nn.init.normal_( + module.text_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPForImageClassification): + nn.init.normal_( + module.classifier.weight, + std=self.config.vision_config.hidden_size**-0.5 + * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +CLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, config: CLIPConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, encoder_states, all_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class CLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + # For `pooled_output` computation + self.eos_token_id = config.eos_token_id + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask( + attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + if self.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. + # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax( + dim=-1 + ), + ] + else: + # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) + pooled_output = last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer) + ( + input_ids.to(dtype=torch.int, device=last_hidden_state.device) + == self.eos_token_id + ) + .int() + .argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPVisionTransformer(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The vision model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPVisionModel(CLIPPreTrainedModel): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + self.vision_model = CLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModel + + >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(CLIP_START_DOCSTRING) +class CLIPModel(CLIPPreTrainedModel): + config_class = CLIPConfig + _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + if not isinstance(config.text_config, CLIPTextConfig): + raise ValueError( + "config.text_config is expected to be of type CLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, CLIPVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type CLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer(text_config) + self.vision_model = CLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear( + self.vision_embed_dim, self.projection_dim, bias=False + ) + self.text_projection = nn.Linear( + self.text_embed_dim, self.projection_dim, bias=False + ) + self.logit_scale = nn.Parameter( + torch.tensor(self.config.logit_scale_init_value) + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`CLIPTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`CLIPVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CLIPOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = ( + logits_per_image, + logits_per_text, + text_embeds, + image_embeds, + text_outputs, + vision_outputs, + ) + return ((loss,) + output) if loss is not None else output + + return CLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@add_start_docstrings( + """ + CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output). + """, + CLIP_START_DOCSTRING, +) +class CLIPTextModelWithProjection(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + + self.text_model = CLIPTextTransformer(config) + + self.text_projection = nn.Linear( + config.hidden_size, config.projection_dim, bias=False + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CLIPTextModelOutput, config_class=CLIPTextConfig + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CLIPTextModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection + + >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> text_embeds = outputs.text_embeds + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + text_embeds = self.text_projection(pooled_output) + + if not return_dict: + outputs = (text_embeds, text_outputs[0]) + text_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return CLIPTextModelOutput( + text_embeds=text_embeds, + last_hidden_state=text_outputs.last_hidden_state, + hidden_states=text_outputs.hidden_states, + attentions=text_outputs.attentions, + ) + + +@add_start_docstrings( + """ + CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output). + """, + CLIP_START_DOCSTRING, +) +class CLIPVisionModelWithProjection(CLIPPreTrainedModel): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + + self.vision_model = CLIPVisionTransformer(config) + + self.visual_projection = nn.Linear( + config.hidden_size, config.projection_dim, bias=False + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CLIPVisionModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection + + >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> image_embeds = outputs.image_embeds + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + + image_embeds = self.visual_projection(pooled_output) + + if not return_dict: + outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return CLIPVisionModelOutput( + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + +@add_start_docstrings( + """ + CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of + the patch tokens) e.g. for ImageNet. + """, + CLIP_START_DOCSTRING, +) +class CLIPForImageClassification(CLIPPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: CLIPConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.vision_model = CLIPVisionTransformer(config.vision_config) + + # Classifier head + self.classifier = ( + nn.Linear(config.vision_config.hidden_size, config.num_labels) + if config.num_labels > 0 + else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.vision_model( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # average pool the patch tokens + sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1) + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/step1x3d_geometry/models/conditional_encoders/clip/modeling_conditional_clip.py b/step1x3d_geometry/models/conditional_encoders/clip/modeling_conditional_clip.py new file mode 100755 index 0000000000000000000000000000000000000000..812bf95ba19f5640f21adf714a830ff7033453ce --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/clip/modeling_conditional_clip.py @@ -0,0 +1,443 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Reference: +# * transformers/models/dinov2/modeling_dinov2.py +# * https://github.com/facebookresearch/DiT/blob/main/models.py#L101 +# * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2 +"""PyTorch CLIP model.""" + +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn + +from .modeling_clip import ( + CLIPConfig, + CLIPTextConfig, + CLIPVisionConfig, + CLIPEncoderLayer, + CLIPTextTransformer, + CLIPVisionTransformer, + CLIPModel, + CLIPVisionEmbeddings, + CLIPVisionModel, + CLIPOutput, + BaseModelOutput, + BaseModelOutputWithPooling, +) + + +class ModLN(nn.Module): + def __init__(self, inner_dim: int, mod_dim: int = 32): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(mod_dim, inner_dim * 2), + ) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x: torch.Tensor, condition: torch.Tensor): + """ + x: [N, M, C_in], M: num of tokens + condition: [N, C_mod] + """ + shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1) + return x * (1 + scale) + shift + + +class ConditionalCLIPVisionConfig(CLIPVisionConfig): + def __init__(self, modulation_dim: int = 32, *args, **kwargs): + super().__init__(*args, **kwargs) + self.modulation_dim = modulation_dim + + +class ConditionalCLIPEncoderLayer(CLIPEncoderLayer): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: ConditionalCLIPVisionConfig) -> None: + super().__init__(config) + self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim) + self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + condition: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + residual = hidden_states + + hidden_states = self.mod_norm1(self.layer_norm1(hidden_states), condition) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.mod_norm2(self.layer_norm2(hidden_states), condition) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class ConditionalCLIPEncoder(nn.Module): + def __init__(self, config: CLIPConfig) -> None: + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + ConditionalCLIPEncoderLayer(config) + for _ in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + condition: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + causal_attention_mask, + condition=condition, + output_attentions=output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + condition=condition, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, encoder_states, all_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class ConditionalCLIPVisionTransformer(CLIPVisionTransformer): + def __init__(self, config: ConditionalCLIPVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = ConditionalCLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + condition: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + condition=condition, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class ConditionalCLIPVisionModel(CLIPVisionModel): + config_class = ConditionalCLIPVisionConfig + + def __init__(self, config: ConditionalCLIPVisionConfig): + super().__init__(config) + self.vision_model = ConditionalCLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + condition: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + return self.vision_model( + pixel_values=pixel_values, + condition=condition, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class ConditionalCLIPModel(CLIPModel): + config_class = CLIPConfig + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + if not isinstance(config.text_config, CLIPTextConfig): + raise ValueError( + "config.text_config is expected to be of type CLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, CLIPVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type CLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer(text_config) + self.vision_model = ConditionalCLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear( + self.vision_embed_dim, self.projection_dim, bias=False + ) + self.text_projection = nn.Linear( + self.text_embed_dim, self.projection_dim, bias=False + ) + self.logit_scale = nn.Parameter( + torch.tensor(self.config.logit_scale_init_value) + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + condition: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + condition=condition, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + condition: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CLIPOutput]: + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + condition=condition, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = ( + logits_per_image, + logits_per_text, + text_embeds, + image_embeds, + text_outputs, + vision_outputs, + ) + return ((loss,) + output) if loss is not None else output + + return CLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) diff --git a/step1x3d_geometry/models/conditional_encoders/dinov2/modeling_conditional_dinov2.py b/step1x3d_geometry/models/conditional_encoders/dinov2/modeling_conditional_dinov2.py new file mode 100755 index 0000000000000000000000000000000000000000..cf45910552e85a8cbf81470503f0cffe6a4c8e15 --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/dinov2/modeling_conditional_dinov2.py @@ -0,0 +1,248 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Reference: +# * transformers/models/dinov2/modeling_dinov2.py +# * https://github.com/facebookresearch/DiT/blob/main/models.py#L101 +# * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2 +"""PyTorch DINOv2 model.""" + +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn + +from .modeling_dinov2 import ( + Dinov2Config, + Dinov2Layer, + Dinov2Model, + Dinov2Embeddings, + BaseModelOutput, + BaseModelOutputWithPooling, +) + + +class ModLN(nn.Module): + def __init__(self, inner_dim: int, mod_dim: int = 1024): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(mod_dim, inner_dim * 2), + ) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x: torch.Tensor, condition: torch.Tensor): + """ + x: [N, M, C_in], M: num of tokens + condition: [N, C_mod] + """ + shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1) + return x * (1 + scale) + shift + + +class ConditionalDinov2Config(Dinov2Config): + def __init__(self, modulation_dim: int = 1024, *args, **kwargs): + super().__init__(*args, **kwargs) + self.modulation_dim = modulation_dim + + +class ConditionalDinov2Layer(Dinov2Layer): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: ConditionalDinov2Config) -> None: + super().__init__(config) + self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim) + self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + condition: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.mod_norm1( + self.norm1(hidden_states), condition + ), # in Dinov2, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[ + 1: + ] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.mod_norm2(self.norm2(hidden_states), condition) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2 +class ConditionalDinov2Encoder(nn.Module): + def __init__(self, config: ConditionalDinov2Config) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [ConditionalDinov2Layer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + condition: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + condition, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + layer_head_mask, + condition, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ConditionalDinov2Model(Dinov2Model): + config_class = ConditionalDinov2Config + + def __init__(self, config: ConditionalDinov2Config): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2Embeddings(config) + self.encoder = ConditionalDinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + condition: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + condition=condition, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/step1x3d_geometry/models/conditional_encoders/dinov2/modeling_dinov2.py b/step1x3d_geometry/models/conditional_encoders/dinov2/modeling_dinov2.py new file mode 100755 index 0000000000000000000000000000000000000000..deabab5735ae44278252e60baee361f350f769f1 --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/dinov2/modeling_dinov2.py @@ -0,0 +1,978 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DINOv2 model.""" + + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BackboneOutput, + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.backbone_utils import BackboneMixin +from transformers.models.dinov2.configuration_dinov2 import Dinov2Config + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Dinov2Config" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2-base" +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/dinov2-base", + # See all DINOv2 models at https://huggingface.co/models?filter=dinov2 +] + + +class Dinov2Embeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.patch_embeddings = Dinov2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1, config.hidden_size) + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + target_dtype = patch_pos_embed.dtype + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(dtype=torch.float32), + scale_factor=( + float(height / math.sqrt(num_positions)), + float(width / math.sqrt(num_positions)), + ), + mode="bicubic", + align_corners=False, + ).to(dtype=target_dtype) + if ( + int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with the interpolated position embeddings" + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), + self.mask_token.to(embeddings.dtype).unsqueeze(0), + embeddings, + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2 +class Dinov2SelfAttention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.key = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.value = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 +class Dinov2SelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2 +class Dinov2Attention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.attention = Dinov2SelfAttention(config) + self.output = Dinov2SelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.attention.num_attention_heads, + self.attention.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len( + heads + ) + self.attention.all_head_size = ( + self.attention.attention_head_size * self.attention.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class Dinov2LayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter( + config.layerscale_value * torch.ones(config.hidden_size) + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path( + input: torch.Tensor, drop_prob: float = 0.0, training: bool = False +) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * ( + input.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=input.dtype, device=input.device + ) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath +class Dinov2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Dinov2MLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2SwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +class Dinov2Layer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = Dinov2Attention(config) + self.layer_scale1 = Dinov2LayerScale(config) + self.drop_path = ( + Dinov2DropPath(config.drop_path_rate) + if config.drop_path_rate > 0.0 + else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = Dinov2SwiGLUFFN(config) + else: + self.mlp = Dinov2MLP(config) + self.layer_scale2 = Dinov2LayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.norm1( + hidden_states + ), # in Dinov2, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[ + 1: + ] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2 +class Dinov2Encoder(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Dinov2Layer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, layer_head_mask, output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Dinov2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Dinov2Config + base_model_prefix = "dinov2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2Embeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +DINOV2_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINOV2_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +DINOV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_START_DOCSTRING, +) +class Dinov2Model(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + DINOV2_START_DOCSTRING, +) +class Dinov2ForImageClassification(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2 = Dinov2Model(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) + if config.num_labels > 0 + else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.dinov2( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] # batch_size, sequence_length, hidden_size + + cls_token = sequence_output[:, 0] + patch_tokens = sequence_output[:, 1:] + + linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) + + logits = self.classifier(linear_input) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 backbone, to be used with frameworks like DETR and MaskFormer. + """, + DINOV2_START_DOCSTRING, +) +class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [ + config.hidden_size for _ in range(config.num_hidden_layers + 1) + ] + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, 1:] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state.reshape( + batch_size, height // patch_size, width // patch_size, -1 + ) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) diff --git a/step1x3d_geometry/models/conditional_encoders/dinov2_clip_encoder.py b/step1x3d_geometry/models/conditional_encoders/dinov2_clip_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..38c6bd6e94fd1f6829d01a10c6f6c8bdb4c6808c --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/dinov2_clip_encoder.py @@ -0,0 +1,514 @@ +import random +import torch +from torch import nn +import numpy as np +import re +from einops import rearrange +from dataclasses import dataclass +from torchvision import transforms + +from diffusers.models.modeling_utils import ModelMixin +from transformers import CLIPTokenizer, CLIPImageProcessor +from transformers import AutoImageProcessor, AutoModel +from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer +from transformers.utils import ModelOutput +from typing import Iterable, Optional, Union, List + +import step1x3d_geometry +from step1x3d_geometry.utils.typing import * +from .clip.modeling_clip import CLIPModel +from .clip.modeling_conditional_clip import ConditionalCLIPModel +from .base import BaseVisualEncoder, ImageType +from .dinov2.modeling_dinov2 import Dinov2Model +from .dinov2.modeling_conditional_dinov2 import ConditionalDinov2Model +from .dinov2_with_registers.modeling_dinov2_with_registers import ( + Dinov2WithRegistersModel, +) + +CLIP_IMAGE_SIZE = 224 + + +@dataclass +class CLIPEmbedOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + embeds: torch.FloatTensor = None + + +class DINOEmbedOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + + +@step1x3d_geometry.register("dinov2-clip-encoder") +class Dinov2CLIPEncoder(BaseVisualEncoder, ModelMixin): + + @dataclass + class Config(BaseVisualEncoder.Config): + pretrained_model_name_or_path: Optional[str] = ( + None # the pretrained model name or path for condition model + ) + pretrained_clip_name_or_path: Optional[str] = ( + None # the pretrained model name or path for clip + ) + pretrained_dino_name_or_path: Optional[str] = ( + None # the pretrained model name or path for dino + ) + pretrained_linear_proj: Optional[str] = None + freeze_modulation_clip: bool = False + freeze_modulation_dino: bool = False + enable_gradient_checkpointing: bool = False + image_size: int = CLIP_IMAGE_SIZE + fuse_type: str = "concat" + + dino_type: Optional[str] = None + clip_type: Optional[str] = None + kwargs: Optional[dict] = None + + cfg: Config + + def configure(self) -> None: + super().configure() + + # Load the CLIP model and processor + if not self.cfg.encode_camera: + if self.cfg.pretrained_clip_name_or_path is not None: + self.cfg.clip_type = f"openai/{self.cfg.pretrained_clip_name_or_path.split('openai--')[-1].split('/')[0]}" + self.clip_model: CLIPModel = CLIPModel.from_pretrained( + self.cfg.pretrained_clip_name_or_path + ) + else: + print("Loading CLIP model from openai/clip-vit-large-patch14") + self.dino_type = "openai/clip-vit-large-patch14" + self.clip_model: CLIPModel = CLIPModel( + config=ConditionalCLIPModel.config_class.from_pretrained( + "openai/clip-vit-large-patch14", + ) + ) + if self.cfg.pretrained_dino_name_or_path is not None: + self.cfg.dino_type = f"facebook/{self.cfg.pretrained_dino_name_or_path.split('facebook--')[-1].split('/')[0]}" + self.dino_model: Dinov2Model = AutoModel.from_pretrained( + self.cfg.pretrained_dino_name_or_path + ) + else: + if ( + self.cfg.pretrained_model_name_or_path is None + ): # default to load Dinov2-base model + assert ( + self.cfg.dino_type is not None + ), "The dino_type should be provided" + print(f"Loading Dinov2 model from {self.cfg.dino_type}") + if "reg" in self.cfg.dino_type: + self.dino_model: Dinov2WithRegistersModel = ( + Dinov2WithRegistersModel( + config=Dinov2WithRegistersModel.config_class.from_pretrained( + self.cfg.dino_type, + ) + ) + ) + else: + self.dino_model: Dinov2Model = Dinov2Model( + config=Dinov2Model.config_class.from_pretrained( + self.dino_type, + ) + ) + elif "dinov2base" in self.cfg.pretrained_model_name_or_path: + print("Loading Dinov2 model from facebook/dinov2-base") + self.cfg.dino_type = "facebook/dinov2-base" + self.dino_model: Dinov2Model = Dinov2Model( + config=Dinov2Model.config_class.from_pretrained( + "facebook/dinov2-base", + ) + ) + elif "dinov2regbase" in self.cfg.pretrained_model_name_or_path: + print( + "Loading Dinov2 model from facebook/dinov2-with-registers-base" + ) + self.cfg.dino_type = "facebook/dinov2-with-registers-base" + self.dino_model: Dinov2WithRegistersModel = ( + Dinov2WithRegistersModel( + config=Dinov2WithRegistersModel.config_class.from_pretrained( + "facebook/dinov2-with-registers-base", + ) + ) + ) + elif "dinov2reglarge" in self.cfg.pretrained_model_name_or_path: + print( + "Loading Dinov2 model from facebook/dinov2-with-registers-large" + ) + self.cfg.dino_type = "facebook/dinov2-with-registers-large" + self.dino_model: Dinov2WithRegistersModel = ( + Dinov2WithRegistersModel( + config=Dinov2WithRegistersModel.config_class.from_pretrained( + "facebook/dinov2-with-registers-large", + ) + ) + ) + else: + raise ValueError( + f"Unknown Dinov2 model: {self.cfg.pretrained_model_name_or_path}" + ) + else: + # clip + conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained( + self.cfg.pretrained_clip_name_or_path, + ) + conditional_clip_config.vision_config.modulation_dim = ( + self.cfg.camera_embeds_dim + ) + self.clip_model: CLIPModel = ConditionalCLIPModel.from_pretrained( + self.cfg.pretrained_clip_name_or_path, + vision_config=conditional_clip_config.vision_config, + ) + + # dino + conditional_vit_config = ( + ConditionalDinov2Model.config_class.from_pretrained( + self.cfg.pretrained_dino_name_or_path, + ) + ) + conditional_vit_config.modulation_dim = self.cfg.camera_embeds_dim + self.dino_model: ConditionalDinov2Model = ( + ConditionalDinov2Model.from_pretrained( + self.cfg.pretrained_dino_name_or_path, config=conditional_vit_config + ) + ) + + self.image_preprocess_clip = CLIPImageProcessor() + self.image_preprocess_dino = AutoImageProcessor.from_pretrained( + self.cfg.dino_type + if self.cfg.pretrained_dino_name_or_path is None + else self.cfg.pretrained_dino_name_or_path + ) + self.transform_clip = transforms.Compose( + [ + transforms.Resize( + CLIP_IMAGE_SIZE, + transforms.InterpolationMode.BICUBIC, + antialias=True, + ), # clip is CLIP_IMAGE_SIZE + transforms.CenterCrop(CLIP_IMAGE_SIZE), # crop a square. + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ), + ] + ) + self.transform_dino = transforms.Compose( + [ + transforms.Resize( + self.cfg.image_size, + transforms.InterpolationMode.BICUBIC, + antialias=True, + ), + transforms.CenterCrop(self.cfg.image_size), # crop a square + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ] + ) + + if self.cfg.enable_gradient_checkpointing: + self.dino_model.encoder.gradient_checkpointing = True + + if self.cfg.zero_uncond_embeds: + image_size = max(self.cfg.image_size, self.cfg.image_size) + self.empty_image_embeds_dino = torch.zeros( + (self.cfg.n_views, (image_size // 14) ** 2 + 1, 1024) + ).detach() + self.empty_image_embeds_clip = torch.zeros( + (self.cfg.n_views, (CLIP_IMAGE_SIZE // 14) ** 2 + 1, 1024) + ).detach() + if self.cfg.fuse_type == "concat": + self.empty_image_embeds = torch.cat( + [self.empty_image_embeds_dino, self.empty_image_embeds_clip], dim=1 + ) + else: + raise ValueError + else: + if self.cfg.encode_camera: + self.empty_image_embeds_dino = self.encode_image_dino( + torch.zeros( + self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3 + ), + self.cameras[: self.cfg.n_views], + ).detach() + self.empty_image_embeds_clip = self.encode_image_clip( + torch.zeros( + self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3 + ), + self.cameras[: self.cfg.n_views], + ).detach() + else: + self.empty_image_embeds_dino = self.encode_image_dino( + torch.zeros( + self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3 + ) + ).detach() + self.empty_image_embeds_clip = self.encode_image_clip( + torch.zeros( + self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3 + ) + ).detach() + self.empty_image_embeds_clip, self.empty_image_embeds_dino = ( + self.align_clip_dino( + self.empty_image_embeds_clip, self.empty_image_embeds_dino + ) + ) + self.empty_image_embeds = torch.cat( + [self.empty_image_embeds_dino, self.empty_image_embeds_clip], dim=1 + ) + + # Freeze the clip model parameters + self.clip_model.eval() + for k, p in self.clip_model.named_parameters(): + ks = k.split(".") + if ( + "mod_norm1" in ks + or "mod_norm2" in ks + and not self.cfg.freeze_modulation_clip + ): + p.requires_grad_(not self.cfg.freeze_modulation_clip) + else: + p.requires_grad_(False) + + # freeze the dino model parameters + self.dino_model.eval() + for k, p in self.dino_model.named_parameters(): + ks = k.split(".") + if ( + "mod_norm1" in ks + or "mod_norm2" in ks + and not self.cfg.freeze_modulation_dino + ): + p.requires_grad_(not self.cfg.freeze_modulation_dino) + else: + p.requires_grad_(False) + + # add a linear projection layer to project the dino embeddings to the same dimension as clip embeddings + if ( + self.clip_model.config.vision_config.hidden_size + != self.dino_model.config.hidden_size + ): + self.linear_proj = nn.Linear( + self.clip_model.config.vision_config.hidden_size, + self.dino_model.config.vision_config.hidden_size, + bias=False, + ) + else: + self.linear_proj = nn.Identity() + + if self.cfg.pretrained_model_name_or_path is not None: + print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}") + ckpt = torch.load( + self.cfg.pretrained_model_name_or_path, map_location="cpu" + )["state_dict"] + pretrained_model_ckpt = {} + for k, v in ckpt.items(): + if k.startswith("condition."): + pretrained_model_ckpt[k.replace("condition.", "")] = v + self.load_state_dict(pretrained_model_ckpt, strict=True) + + def encode_image_clip( + self, + images: Iterable[Optional[ImageType]], + cameras: Optional[torch.Tensor] = None, + force_none_camera_embeds: bool = False, + return_dict: bool = False, + **kwargs, + ) -> torch.FloatTensor: + camera_embeds = None + if isinstance(images, (np.ndarray, torch.Tensor)): # for training process + assert ( + images.min() >= 0.0 and images.max() <= 1.0 + ), "The pixel values should be in the range of [0, 1]" + if self.cfg.encode_camera: + assert cameras is not None, "The cameras should be provided" + camera_embeds = self.encode_camera(cameras) + pixel_values = self.transform_clip(images.permute(0, 3, 1, 2)) + else: # for inference process + if self.cfg.encode_camera: + if cameras is None: + bs = len(images) // self.cfg.n_views + cameras = ( + self.cameras[: self.cfg.n_views] + .repeat(bs, 1, 1) + .to(self.clip_model.device) + ) + camera_embeds = self.encode_camera(cameras) + pixel_values = self.image_preprocess_clip.preprocess( + images, + return_tensors="pt", + do_rescale=True, + do_resize=True, + size=CLIP_IMAGE_SIZE, + crop_size=CLIP_IMAGE_SIZE, + ).pixel_values + + if force_none_camera_embeds: + camera_embeds = None + + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(1) + if camera_embeds is not None: + camera_embeds = camera_embeds.unsqueeze(1) + + if self.cfg.encode_camera and camera_embeds is not None: + vision_outputs = self.clip_model.vision_model( + pixel_values=rearrange( + pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W" + ), + condition=rearrange(camera_embeds, "B N C -> (B N) C"), + ) + + else: + vision_outputs = self.clip_model.vision_model( + pixel_values=rearrange( + pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W" + ), + ) + + if return_dict: + # clip + pooler_output = vision_outputs[1] # pooled_output + image_features = self.clip_model.visual_projection(pooler_output) + clip_embeds = vision_outputs.last_hidden_state + + clip_embeds_dict = CLIPEmbedOutput( + last_hidden_state=clip_embeds, + pooler_output=pooler_output, + embeds=image_features, + ) + + return clip_embeds_dict + else: + return vision_outputs.last_hidden_state + + def encode_image_dino( + self, + images: Iterable[Optional[ImageType]], + cameras: Optional[torch.Tensor] = None, + force_none_camera_embeds: bool = False, + return_dict: bool = False, + **kwargs, + ) -> torch.FloatTensor: + camera_embeds = None + if isinstance(images, (np.ndarray, torch.Tensor)): # for training process + assert ( + images.min() >= 0.0 and images.max() <= 1.0 + ), "The pixel values should be in the range of [0, 1]" + if self.cfg.encode_camera: + assert cameras is not None, "The cameras should be provided" + camera_embeds = self.encode_camera(cameras) + pixel_values = self.transform_dino(images.permute(0, 3, 1, 2)) + else: # for inference process + if self.cfg.encode_camera: + if cameras is None: + bs = len(images) // self.cfg.n_views + cameras = ( + self.cameras[: self.cfg.n_views] + .repeat(bs, 1, 1) + .to(self.dino_model.device) + ) + camera_embeds = self.encode_camera(cameras) + pixel_values = self.image_preprocess_dino.preprocess( + images, + return_tensors="pt", + do_rescale=True, + do_resize=True, + size=self.cfg.image_size, + crop_size=self.cfg.image_size, + ).pixel_values + + if force_none_camera_embeds: + camera_embeds = None + + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(1) + if camera_embeds is not None: + camera_embeds = camera_embeds.unsqueeze(1) + + if self.cfg.encode_camera and camera_embeds is not None: + vision_outputs = self.dino_model( + rearrange( + pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W" + ), + condition=rearrange(camera_embeds, "B N C -> (B N) C"), + ) + else: + vision_outputs = self.dino_model( + rearrange( + pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W" + ), + ) + + if return_dict: + # dino + dino_embeds_dict = DINOEmbedOutput( + last_hidden_state=vision_outputs.last_hidden_state, + pooler_output=vision_outputs.pooler_output, + ) + return dino_embeds_dict + else: + return vision_outputs.last_hidden_state + + def align_clip_dino(self, clip_embeds, dino_embeds): + if ( + clip_embeds.shape[-2] != dino_embeds.shape[-2] + ): # different shape, interpolate the clip embeddings to the same shape as dino embeddings + assert ( + clip_embeds.shape[-2] == (self.cfg.image_size // 14) ** 2 + 1 + ), "The clip embeddings should have the shape of (n_views, (image_size // 14) ** 2 + 1, 1024)" + clip_embeds_patch_tokens = clip_embeds[:, 1:].view( + clip_embeds.shape[0], + self.cfg.image_size // 14, + self.cfg.image_size // 14, + 1024, + ) + clip_embeds_patch_tokens = ( + torch.nn.functional.interpolate( + clip_embeds_patch_tokens.permute(0, 3, 1, 2), + size=(self.cfg.image_size // 14, self.cfg.image_size // 14), + mode="bilinear", + align_corners=False, + ) + .permute(0, 2, 3, 1) + .view(clip_embeds.shape[0], -1, 1024) + ) + clip_embeds = torch.cat( + [clip_embeds[:, :1], clip_embeds_patch_tokens], dim=1 + ) + return clip_embeds, dino_embeds + + def encode_image( + self, + images: Iterable[Optional[ImageType]], + cameras: Optional[torch.Tensor] = None, + force_none_camera_embeds: bool = False, + return_dict: bool = False, + **kwargs, + ) -> torch.FloatTensor: + clip_embeds = self.encode_image_clip(images, cameras) + dino_embeds = self.encode_image_dino(images, cameras) + if ( + self.dino_model.__class__.__name__ == "Dinov2WithRegistersModel" + ): # x_norm_clstoken, x_norm_regtokens, x_norm_patchtokens + dino_embeds = torch.cat( + [ + dino_embeds[:, :1], + dino_embeds[:, self.dino_model.config.num_register_tokens + 1 :], + ], + dim=1, + ) + + clip_embeds = self.linear_proj(clip_embeds) # bs, 257, 1024 + + if self.cfg.fuse_type == "concat": + visual_embeds = torch.cat([dino_embeds, clip_embeds], dim=1) + # elif self.cfg.fuse_type == 'add': + # clip_embeds, dino_embeds = self.align_clip_dino(clip_embeds, dino_embeds) + else: + raise ValueError + + return visual_embeds diff --git a/step1x3d_geometry/models/conditional_encoders/dinov2_encoder.py b/step1x3d_geometry/models/conditional_encoders/dinov2_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..e499783bd301e9c8f9c86515e50114c575ba675a --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/dinov2_encoder.py @@ -0,0 +1,296 @@ +import random +import torch +from torch import nn +import numpy as np +import re +from einops import rearrange +from dataclasses import dataclass +from torchvision import transforms + +from diffusers.models.modeling_utils import ModelMixin +from transformers import AutoImageProcessor, AutoModel +from transformers.utils import ModelOutput +from typing import Iterable, Optional, Union, List + +import step1x3d_geometry +from step1x3d_geometry.utils.typing import * +from .base import BaseVisualEncoder, ImageType +from .dinov2.modeling_dinov2 import Dinov2Model +from .dinov2.modeling_conditional_dinov2 import ConditionalDinov2Model +from .dinov2_with_registers.modeling_dinov2_with_registers import ( + Dinov2WithRegistersModel, +) + + +class DINOEmbedOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + + +@step1x3d_geometry.register("dinov2-encoder") +class Dinov2Encoder(BaseVisualEncoder, ModelMixin): + + @dataclass + class Config(BaseVisualEncoder.Config): + pretrained_model_name_or_path: Optional[str] = ( + None # the pretrained model name or path for condition model + ) + pretrained_dino_name_or_path: Optional[str] = ( + None # the pretrained model name or path for dino + ) + freeze_modulation_dino: bool = False + enable_gradient_checkpointing: bool = False + image_size: int = 224 + dino_type: Optional[str] = None + kwargs: Optional[dict] = None + + cfg: Config + + def configure(self) -> None: + super().configure() + + # Load the DINOV2 model and processor + if not self.cfg.encode_camera: + if self.cfg.pretrained_dino_name_or_path is not None: + self.cfg.dino_type = f"facebook/{self.cfg.pretrained_dino_name_or_path.split('facebook--')[-1].split('/')[0]}" + if self.cfg.kwargs is not None: + self.dino_model: Dinov2Model = AutoModel.from_pretrained( + self.cfg.pretrained_dino_name_or_path, **self.cfg.kwargs + ) + else: + self.dino_model: Dinov2Model = AutoModel.from_pretrained( + self.cfg.pretrained_dino_name_or_path + ) + else: + if ( + self.cfg.pretrained_model_name_or_path is None + ): # default to load Dinov2-base model + assert ( + self.cfg.dino_type is not None + ), "The dino_type should be provided" + print(f"Loading Dinov2 model from {self.cfg.dino_type}") + if "reg" in self.cfg.dino_type: + self.dino_model: Dinov2WithRegistersModel = ( + Dinov2WithRegistersModel( + config=Dinov2WithRegistersModel.config_class.from_pretrained( + self.cfg.dino_type, + ) + ) + ) + else: + self.dino_model: Dinov2Model = Dinov2Model( + config=Dinov2Model.config_class.from_pretrained( + self.dino_type, + ) + ) + elif "dinov2base" in self.cfg.pretrained_model_name_or_path: + print("Loading Dinov2 model from facebook/dinov2-base") + self.cfg.dino_type = "facebook/dinov2-base" + self.dino_model: Dinov2Model = Dinov2Model( + config=Dinov2Model.config_class.from_pretrained( + "facebook/dinov2-base", + ) + ) + elif "dinov2regbase" in self.cfg.pretrained_model_name_or_path: + print( + "Loading Dinov2 model from facebook/dinov2-with-registers-base" + ) + self.cfg.dino_type = "facebook/dinov2-with-registers-base" + self.dino_model: Dinov2WithRegistersModel = ( + Dinov2WithRegistersModel( + config=Dinov2WithRegistersModel.config_class.from_pretrained( + "facebook/dinov2-with-registers-base", + ) + ) + ) + elif "dinov2reglarge" in self.cfg.pretrained_model_name_or_path: + print( + "Loading Dinov2 model from facebook/dinov2-with-registers-large" + ) + self.cfg.dino_type = "facebook/dinov2-with-registers-large" + self.dino_model: Dinov2WithRegistersModel = ( + Dinov2WithRegistersModel( + config=Dinov2WithRegistersModel.config_class.from_pretrained( + "facebook/dinov2-with-registers-large", + ) + ) + ) + else: + raise ValueError( + f"Unknown Dinov2 model: {self.cfg.pretrained_model_name_or_path}" + ) + else: + # dino + conditional_vit_config = ( + ConditionalDinov2Model.config_class.from_pretrained( + self.cfg.pretrained_dino_name_or_path, + ) + ) + conditional_vit_config.modulation_dim = self.cfg.camera_embeds_dim + self.dino_model: ConditionalDinov2Model = ( + ConditionalDinov2Model.from_pretrained( + self.cfg.pretrained_dino_name_or_path, config=conditional_vit_config + ) + ) + + self.image_preprocess_dino = AutoImageProcessor.from_pretrained( + self.cfg.dino_type + if self.cfg.pretrained_dino_name_or_path is None + else self.cfg.pretrained_dino_name_or_path + ) + self.transform_dino = transforms.Compose( + [ + transforms.Resize( + self.cfg.image_size, + transforms.InterpolationMode.BICUBIC, + antialias=True, + ), + transforms.CenterCrop( + self.cfg.image_size + ), # crop a (image_size, image_size) square + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ] + ) + + if self.cfg.enable_gradient_checkpointing: + self.dino_model.encoder.gradient_checkpointing = True + + if self.cfg.zero_uncond_embeds: + self.empty_image_embeds = torch.zeros( + ( + self.cfg.n_views, + (self.cfg.image_size // 14) ** 2 + 1, + self.dino_model.config.hidden_size, + ) + ).detach() + else: + if self.cfg.encode_camera: + self.empty_image_embeds = self.encode_image_dino( + torch.zeros( + self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3 + ), + self.cameras[: self.cfg.n_views], + ).detach() + else: + self.empty_image_embeds = self.encode_image_dino( + torch.zeros( + self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3 + ) + ).detach() + + # freeze the dino model parameters + self.dino_model.eval() + for k, p in self.dino_model.named_parameters(): + ks = k.split(".") + if ( + "mod_norm1" in ks + or "mod_norm2" in ks + and not self.cfg.freeze_modulation_dino + ): + p.requires_grad_(not self.cfg.freeze_modulation_dino) + else: + p.requires_grad_(False) + + # load pretrained_model_name_or_path + if self.cfg.pretrained_model_name_or_path is not None: + print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}") + ckpt = torch.load( + self.cfg.pretrained_model_name_or_path, map_location="cpu" + )["state_dict"] + pretrained_model_ckpt = {} + for k, v in ckpt.items(): + if k.startswith("visual_condition."): + pretrained_model_ckpt[k.replace("visual_condition.", "")] = v + self.load_state_dict(pretrained_model_ckpt, strict=True) + + def encode_image_dino( + self, + images: Iterable[Optional[ImageType]], + cameras: Optional[torch.Tensor] = None, + force_none_camera_embeds: bool = False, + return_dict: bool = False, + **kwargs, + ) -> torch.FloatTensor: + camera_embeds = None + if isinstance(images, (np.ndarray, torch.Tensor)): # for training process + assert ( + images.min() >= 0.0 and images.max() <= 1.0 + ), "The pixel values should be in the range of [0, 1]" + if self.cfg.encode_camera: + assert cameras is not None, "The cameras should be provided" + camera_embeds = self.encode_camera(cameras) + pixel_values = self.transform_dino(images.permute(0, 3, 1, 2)) + else: # for inference process + if self.cfg.encode_camera: + if cameras is None: + bs = len(images) // self.cfg.n_views + cameras = ( + self.cameras[: self.cfg.n_views] + .repeat(bs, 1, 1) + .to(self.dino_model.device) + ) + camera_embeds = self.encode_camera(cameras) + pixel_values = self.image_preprocess_dino.preprocess( + images, + return_tensors="pt", + do_rescale=True, + do_resize=True, + size=self.cfg.image_size, + crop_size=self.cfg.image_size, + ).pixel_values + + if force_none_camera_embeds: + camera_embeds = None + + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(1) + if camera_embeds is not None: + camera_embeds = camera_embeds.unsqueeze(1) + + if self.cfg.encode_camera and camera_embeds is not None: + vision_outputs = self.dino_model( + rearrange( + pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W" + ), + condition=rearrange(camera_embeds, "B N C -> (B N) C"), + ) + else: + vision_outputs = self.dino_model( + rearrange( + pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W" + ), + ) + + if return_dict: + # dino + dino_embeds_dict = DINOEmbedOutput( + last_hidden_state=vision_outputs.last_hidden_state, + pooler_output=vision_outputs.pooler_output, + ) + return dino_embeds_dict + else: + return vision_outputs.last_hidden_state + + def encode_image( + self, + images: Iterable[Optional[ImageType]], + cameras: Optional[torch.Tensor] = None, + force_none_camera_embeds: bool = False, + return_dict: bool = False, + **kwargs, + ) -> torch.FloatTensor: + dino_embeds = self.encode_image_dino(images, cameras) + if ( + self.dino_model.__class__.__name__ == "Dinov2WithRegistersModel" + ): # x_norm_clstoken, x_norm_regtokens, x_norm_patchtokens + dino_embeds = torch.cat( + [ + dino_embeds[:, :1], + dino_embeds[:, self.dino_model.config.num_register_tokens + 1 :], + ], + dim=1, + ) + return dino_embeds diff --git a/step1x3d_geometry/models/conditional_encoders/dinov2_with_registers/modeling_dinov2_with_registers.py b/step1x3d_geometry/models/conditional_encoders/dinov2_with_registers/modeling_dinov2_with_registers.py new file mode 100755 index 0000000000000000000000000000000000000000..255051d00b0d376d6fdf775223305f466f3f85db --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -0,0 +1,1088 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dinov2_with_registers.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BackboneOutput, + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from transformers.utils.backbone_utils import BackboneMixin +from transformers.models.dinov2_with_registers.configuration_dinov2_with_registers import ( + Dinov2WithRegistersConfig, +) + + +logger = logging.get_logger(__name__) + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base" + +# General docstring +_CONFIG_FOR_DOC = "Dinov2WithRegistersConfig" + + +class Dinov2WithRegistersPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class Dinov2WithRegistersEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, register tokens, position and patch embeddings. + """ + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.register_tokens = nn.Parameter( + torch.zeros(1, config.num_register_tokens, config.hidden_size) + ) + self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1, config.hidden_size) + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility + with the original implementation. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py + - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # Skip interpolation for matching dimensions (unless tracing) + if ( + not torch.jit.is_tracing() + and num_patches == num_positions + and height == width + ): + return self.position_embeddings + + # Handle class token and patch embeddings separately + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + + # Calculate new dimensions + height = height // self.config.patch_size + width = width // self.config.patch_size + + # Reshape for interpolation + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape( + 1, sqrt_num_positions, sqrt_num_positions, dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + # Store original dtype for restoration after interpolation + target_dtype = patch_pos_embed.dtype + + # Interpolate at float32 precision + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(dtype=torch.float32), + size=( + torch_int(height), + torch_int(width), + ), # Explicit size instead of scale_factor + mode="bicubic", + align_corners=False, + antialias=True, + ).to(dtype=target_dtype) + + # Validate output dimensions if not tracing + if not torch.jit.is_tracing(): + if ( + int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with the interpolated position embeddings" + ) + + # Reshape back to original format + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + # Combine class and patch embeddings + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), + self.mask_token.to(embeddings.dtype).unsqueeze(0), + embeddings, + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + + # add register tokens + embeddings = torch.cat( + ( + embeddings[:, :1], + self.register_tokens.expand(embeddings.shape[0], -1, -1), + embeddings[:, 1:], + ), + dim=1, + ) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2WithRegistersSelfAttention(nn.Module): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.key = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.value = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + return outputs + + +class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + +class Dinov2WithRegistersSelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class Dinov2WithRegistersAttention(nn.Module): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + self.attention = Dinov2WithRegistersSelfAttention(config) + self.output = Dinov2WithRegistersSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.attention.num_attention_heads, + self.attention.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len( + heads + ) + self.attention.all_head_size = ( + self.attention.attention_head_size * self.attention.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__(config) + self.attention = Dinov2WithRegistersSdpaSelfAttention(config) + + +class Dinov2WithRegistersLayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter( + config.layerscale_value * torch.ones(config.hidden_size) + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +def drop_path( + input: torch.Tensor, drop_prob: float = 0.0, training: bool = False +) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * ( + input.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=input.dtype, device=input.device + ) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class Dinov2WithRegistersDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Dinov2WithRegistersMLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2WithRegistersSwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = { + "eager": Dinov2WithRegistersAttention, + "sdpa": Dinov2WithRegistersSdpaAttention, +} + + +class Dinov2WithRegistersLayer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[ + config._attn_implementation + ](config) + self.layer_scale1 = Dinov2WithRegistersLayerScale(config) + self.drop_path = ( + Dinov2WithRegistersDropPath(config.drop_path_rate) + if config.drop_path_rate > 0.0 + else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = Dinov2WithRegistersSwiGLUFFN(config) + else: + self.mlp = Dinov2WithRegistersMLP(config) + self.layer_scale2 = Dinov2WithRegistersLayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.norm1( + hidden_states + ), # in Dinov2WithRegisters, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[ + 1: + ] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2WithRegisters, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +class Dinov2WithRegistersEncoder(nn.Module): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Dinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, layer_head_mask, output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Dinov2WithRegistersPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Dinov2WithRegistersConfig + base_model_prefix = "dinov2_with_registers" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"] + _supports_sdpa = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2WithRegistersEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + + +DINOV2_WITH_REGISTERS_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2WithRegistersConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Dinov2WithRegisters Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class Dinov2WithRegistersModel(Dinov2WithRegistersPreTrainedModel): + def __init__(self, config: Dinov2WithRegistersConfig): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2WithRegistersEmbeddings(config) + self.encoder = Dinov2WithRegistersEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + +DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """ + Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class Dinov2WithRegistersForImageClassification(Dinov2WithRegistersPreTrainedModel): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2_with_registers = Dinov2WithRegistersModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) + if config.num_labels > 0 + else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.dinov2_with_registers( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] # batch_size, sequence_length, hidden_size + + cls_token = sequence_output[:, 0] + patch_tokens = sequence_output[:, 1:] + + linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) + + logits = self.classifier(linear_input) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer. + """, + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + self.num_features = [ + config.hidden_size for _ in range(config.num_hidden_layers + 1) + ] + self.embeddings = Dinov2WithRegistersEmbeddings(config) + self.encoder = Dinov2WithRegistersEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.num_register_tokens = config.num_register_tokens + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + Returns: + + Examples: + + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, self.num_register_tokens + 1 :] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state.reshape( + batch_size, height // patch_size, width // patch_size, -1 + ) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + +__all__ = [ + "Dinov2WithRegistersPreTrainedModel", + "Dinov2WithRegistersModel", + "Dinov2WithRegistersForImageClassification", + "Dinov2WithRegistersBackbone", +] diff --git a/step1x3d_geometry/models/conditional_encoders/label_encoder.py b/step1x3d_geometry/models/conditional_encoders/label_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..113c25861ae43560b989df8fdae7274803f99d6e --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/label_encoder.py @@ -0,0 +1,167 @@ +import random +import torch +from torch import nn +import numpy as np +import re +from einops import rearrange +from dataclasses import dataclass +from torchvision import transforms +from diffusers.models.modeling_utils import ModelMixin + +from transformers.utils import ModelOutput +from typing import Iterable, Optional, Union, List + +import step1x3d_geometry +from step1x3d_geometry.utils.typing import * +from step1x3d_geometry.utils.misc import get_device + +from .base import BaseLabelEncoder + +DEFAULT_POSE = 0 # "unknown", "t-pose", "a-pose", uncond +NUM_POSE_CLASSES = 3 +POSE_MAPPING = {"unknown": 0, "t-pose": 1, "a-pose": 2, "uncond": 3} + +DEFAULT_SYMMETRY_TYPE = 0 # "asymmetry", "x", uncond +NUM_SYMMETRY_TYPE_CLASSES = 2 +SYMMETRY_TYPE_MAPPING = {"asymmetry": 0, "x": 1, "y": 0, "z": 0, "uncond": 2} + +DEFAULT_GEOMETRY_QUALITY = 0 # "normal", "smooth", "sharp", uncond, +NUM_GEOMETRY_QUALITY_CLASSES = 3 +GEOMETRY_QUALITY_MAPPING = {"normal": 0, "smooth": 1, "sharp": 2, "uncod": 3} + + +@step1x3d_geometry.register("label-encoder") +class LabelEncoder(BaseLabelEncoder, ModelMixin): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + """ + + def configure(self) -> None: + super().configure() + + if self.cfg.zero_uncond_embeds: + self.embedding_table_tpose = nn.Embedding( + NUM_POSE_CLASSES, self.cfg.hidden_size + ) + self.embedding_table_symmetry_type = nn.Embedding( + NUM_SYMMETRY_TYPE_CLASSES, self.cfg.hidden_size + ) + self.embedding_table_geometry_quality = nn.Embedding( + NUM_GEOMETRY_QUALITY_CLASSES, self.cfg.hidden_size + ) + else: + self.embedding_table_tpose = nn.Embedding( + NUM_POSE_CLASSES + 1, self.cfg.hidden_size + ) + self.embedding_table_symmetry_type = nn.Embedding( + NUM_SYMMETRY_TYPE_CLASSES + 1, self.cfg.hidden_size + ) + self.embedding_table_geometry_quality = nn.Embedding( + NUM_GEOMETRY_QUALITY_CLASSES + 1, self.cfg.hidden_size + ) + + if self.cfg.zero_uncond_embeds: + self.empty_label_embeds = torch.zeros((1, 3, self.cfg.hidden_size)).detach() + else: + self.empty_label_embeds = ( + self.encode_label( # the last class label is for the uncond + [{"pose": "", "symetry": "", "geometry_type": ""}] + ).detach() + ) + + # load pretrained_model_name_or_path + if self.cfg.pretrained_model_name_or_path is not None: + print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}") + ckpt = torch.load( + self.cfg.pretrained_model_name_or_path, map_location="cpu" + )["state_dict"] + pretrained_model_ckpt = {} + for k, v in ckpt.items(): + if k.startswith("label_condition."): + pretrained_model_ckpt[k.replace("label_condition.", "")] = v + self.load_state_dict(pretrained_model_ckpt, strict=True) + + def encode_label(self, labels: List[dict]) -> torch.FloatTensor: + tpose_label_embeds = [] + symmetry_type_label_embeds = [] + geometry_quality_label_embeds = [] + + for label in labels: + if "pose" in label.keys(): + if label["pose"] is None or label["pose"] == "": + tpose_label_embeds.append( + torch.zeros(self.cfg.hidden_size).detach().to(get_device()) + ) + else: + tpose_label_embeds.append( + self.embedding_table_symmetry_type( + torch.tensor(POSE_MAPPING[label["pose"][0]]).to( + get_device() + ) + ) + ) + else: + tpose_label_embeds.append( + self.embedding_table_tpose( + torch.tensor(DEFAULT_POSE).to(get_device()) + ) + ) + + if "symmetry" in label.keys(): + if label["symmetry"] is None or label["symmetry"] == "": + symmetry_type_label_embeds.append( + torch.zeros(self.cfg.hidden_size).detach().to(get_device()) + ) + else: + symmetry_type_label_embeds.append( + self.embedding_table_symmetry_type( + torch.tensor( + SYMMETRY_TYPE_MAPPING[label["symmetry"][0]] + ).to(get_device()) + ) + ) + else: + symmetry_type_label_embeds.append( + self.embedding_table_symmetry_type( + torch.tensor(DEFAULT_SYMMETRY_TYPE).to(get_device()) + ) + ) + + if "geometry_type" in label.keys(): + if label["geometry_type"] is None or label["geometry_type"] == "": + geometry_quality_label_embeds.append( + torch.zeros(self.cfg.hidden_size).detach().to(get_device()) + ) + else: + geometry_quality_label_embeds.append( + self.embedding_table_geometry_quality( + torch.tensor( + GEOMETRY_QUALITY_MAPPING[label["geometry_type"][0]] + ).to(get_device()) + ) + ) + else: + geometry_quality_label_embeds.append( + self.embedding_table_geometry_quality( + torch.tensor(DEFAULT_GEOMETRY_QUALITY).to(get_device()) + ) + ) + + tpose_label_embeds = torch.stack(tpose_label_embeds) + symmetry_type_label_embeds = torch.stack(symmetry_type_label_embeds) + geometry_quality_label_embeds = torch.stack(geometry_quality_label_embeds) + + label_embeds = torch.stack( + [ + tpose_label_embeds, + symmetry_type_label_embeds, + geometry_quality_label_embeds, + ], + dim=1, + ).to(self.dtype) + + return label_embeds diff --git a/step1x3d_geometry/models/conditional_encoders/t5_encoder.py b/step1x3d_geometry/models/conditional_encoders/t5_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..b022daef8eb67985aedbf4dcbc32c092f9b7f4c2 --- /dev/null +++ b/step1x3d_geometry/models/conditional_encoders/t5_encoder.py @@ -0,0 +1,271 @@ +import random +import torch +from torch import nn +import numpy as np +import re +import urllib.parse as ul +from bs4 import BeautifulSoup +from einops import rearrange +from dataclasses import dataclass +from torchvision import transforms +from diffusers.models.modeling_utils import ModelMixin + +from transformers import AutoImageProcessor, AutoModel +from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer +from transformers.utils import ModelOutput +from typing import Iterable, Optional, Union, List + +import step1x3d_geometry +from step1x3d_geometry.utils.typing import * + +from .base import BaseCaptionEncoder + +bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + "\)" + + "\(" + + "\]" + + "\[" + + "\}" + + "\{" + + "\|" + + "\\" + + "\/" + + "\*" + + r"]{1,}" +) # noqa + + +@step1x3d_geometry.register("t5-encoder") +class T5Encoder(BaseCaptionEncoder, ModelMixin): + + @dataclass + class Config(BaseCaptionEncoder.Config): + pretrained_model_name_or_path: Optional[str] = ( + None # the pretrained model name or path for condition model + ) + pretrained_t5_name_or_path: Optional[str] = ( + None # the pretrained model name or path for T5 + ) + preprocessing_text: bool = False + text_max_length: int = 77 + t5_type: Optional[str] = None + + cfg: Config + + def configure(self) -> None: + super().configure() + + # Load the T5 model and tokenizer + if self.cfg.pretrained_t5_name_or_path is not None: + self.cfg.t5_type = f"google-t5/{self.cfg.pretrained_t5_name_or_path.split('google-t5--')[-1].split('/')[0]}" + self.tokenizer = T5Tokenizer.from_pretrained( + self.cfg.pretrained_t5_name_or_path + ) + self.text_model = T5EncoderModel.from_pretrained( + self.cfg.pretrained_t5_name_or_path, torch_dtype=torch.bfloat16 + ) + else: + if ( + self.cfg.pretrained_model_name_or_path is None + ): # default to load t5-base model + assert self.cfg.t5_type is not None, "The t5_type should be provided" + print(f"Loading T5 model from {self.cfg.t5_type}") + self.text_model = T5EncoderModel( + config=T5EncoderModel.config_class.from_pretrained( + self.cfg.t5_type, + ) + ).to(torch.bfloat16) + elif "t5small" in self.cfg.pretrained_model_name_or_path: + print("Loading Dinov2 model from google-t5/t5-small") + self.cfg.t5_type = "google-t5/t5-small" + self.text_model = T5EncoderModel.from_pretrained( + self.cfg.t5_type, torch_dtype=torch.bfloat16 + ) + elif "t5base" in self.cfg.pretrained_model_name_or_path: + print("Loading Dinov2 model from google-t5/t5-base") + self.cfg.t5_type = "google-t5/t5-base" + self.text_model = T5EncoderModel.from_pretrained( + self.cfg.t5_type, torch_dtype=torch.bfloat16 + ) + else: + raise ValueError( + f"Unknown T5 model: {self.cfg.pretrained_model_name_or_path}" + ) + self.tokenizer = T5Tokenizer.from_pretrained(self.cfg.t5_type) + + # Set the empty image/text embeds + if self.cfg.zero_uncond_embeds: + self.empty_text_embeds = torch.zeros( + (1, self.cfg.text_max_length, self.text_model.config.hidden_size) + ).detach() + else: + self.empty_text_embeds = self.encode_text([""]).detach() + + # load pretrained_model_name_or_path + if self.cfg.pretrained_model_name_or_path is not None: + print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}") + ckpt = torch.load( + self.cfg.pretrained_model_name_or_path, map_location="cpu" + )["state_dict"] + pretrained_model_ckpt = {} + for k, v in ckpt.items(): + if k.startswith("caption_condition."): + pretrained_model_ckpt[k.replace("caption_condition.", "")] = v + self.load_state_dict(pretrained_model_ckpt, strict=True) + + def clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub( + r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption + ) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub( + bad_punct_regex, r" ", caption + ) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = self.basic_clean(caption) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub( + r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption + ) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub( + r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption + ) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def text_preprocessing(self, text): + if self.cfg.preprocessing_text: + # The exact text cleaning as was in the training stage: + text = self.clean_caption(text) + return text + else: + return text.lower().strip() + + def encode_text(self, texts: List[str]) -> torch.FloatTensor: + texts = [self.text_preprocessing(text) for text in texts] + + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.cfg.text_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_tokens_and_mask["input_ids"] = text_tokens_and_mask["input_ids"] # N x 77 + text_tokens_and_mask["attention_mask"] = text_tokens_and_mask["attention_mask"] + + with torch.no_grad(): + label_embeds = self.text_model( + input_ids=text_tokens_and_mask["input_ids"].to(self.text_model.device), + attention_mask=text_tokens_and_mask["attention_mask"].to( + self.text_model.device + ), + )["last_hidden_state"].detach() + + return label_embeds diff --git a/step1x3d_geometry/models/pipelines/pipeline.py b/step1x3d_geometry/models/pipelines/pipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..c8bd4b85572e1455c5d793b1ef62dc8f2544f26e --- /dev/null +++ b/step1x3d_geometry/models/pipelines/pipeline.py @@ -0,0 +1,513 @@ +# Some parts of this file are refer to Hugging Face Diffusers library. +import os +import json +import warnings +from typing import Callable, List, Optional, Union, Dict, Any +import PIL.Image +import trimesh +import rembg +import torch +import numpy as np +from huggingface_hub import hf_hub_download + +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.loaders import ( + FluxIPAdapterMixin, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +) +from .pipeline_utils import ( + TransformerDiffusionMixin, + preprocess_image, + retrieve_timesteps, + remove_floater, + remove_degenerate_face, + reduce_face, + smart_load_model, +) +from transformers import ( + BitImageProcessor, +) + +import step1x3d_geometry +from step1x3d_geometry.models.autoencoders.surface_extractors import MeshExtractResult +from step1x3d_geometry.utils.config import ExperimentConfig, load_config +from ..autoencoders.michelangelo_autoencoder import MichelangeloAutoencoder +from ..conditional_encoders.dinov2_encoder import Dinov2Encoder +from ..conditional_encoders.t5_encoder import T5Encoder +from ..conditional_encoders.label_encoder import LabelEncoder +from ..transformers.flux_transformer_1d import FluxDenoiser + + +class Step1X3DGeometryPipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor`): + List of PIL images or a tensor representing the input images. + meshes (`List[trimesh.Trimesh]` or `np.ndarray`) + List of denoised trimesh meshes of length `batch_size` or a tuple of NumPy array with shape `((vertices, 3), (faces, 3)) of length `batch_size``. + """ + + image: PIL.Image.Image + mesh: Union[trimesh.Trimesh, MeshExtractResult, np.ndarray] + + +class Step1X3DGeometryPipeline( + DiffusionPipeline, FromSingleFileMixin, TransformerDiffusionMixin +): + """ + Step1X-3D Geometry Pipeline, generate high-quality meshes conditioned on image/caption/label inputs + + Args: + scheduler (FlowMatchEulerDiscreteScheduler): + The diffusion scheduler controlling the denoising process + vae (MichelangeloAutoencoder): + Variational Autoencoder for latent space compression/reconstruction + transformer (FluxDenoiser): + Transformer-based denoising model + visual_encoder (Dinov2Encoder): + Pretrained visual encoder for image feature extraction + caption_encoder (T5Encoder): + Text encoder for processing natural language captions + label_encoder (LabelEncoder): + Auxiliary text encoder for label conditioning + visual_eature_extractor (BitImageProcessor): + Preprocessor for input images + + Note: + - CPU offloading sequence: visual_encoder → caption_encoder → label_encoder → transformer → vae + - Optional components: visual_encoder, visual_eature_extractor, caption_encoder, label_encoder + """ + + model_cpu_offload_seq = ( + "visual_encoder->caption_encoder->label_encoder->transformer->vae" + ) + _optional_components = [ + "visual_encoder", + "visual_eature_extractor", + "caption_encoder", + "label_encoder", + ] + + @classmethod + def from_pretrained(cls, model_path, subfolder='.', **kwargs): + local_model_path = smart_load_model(model_path, subfolder) + return super().from_pretrained(local_model_path, **kwargs) + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: MichelangeloAutoencoder, + transformer: FluxDenoiser, + visual_encoder: Dinov2Encoder, + caption_encoder: T5Encoder, + label_encoder: LabelEncoder, + visual_eature_extractor: BitImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + transformer=transformer, + scheduler=scheduler, + visual_encoder=visual_encoder, + caption_encoder=caption_encoder, + label_encoder=label_encoder, + visual_eature_extractor=visual_eature_extractor, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + def check_inputs( + self, + image, + ): + r""" + Check if the inputs are valid. Raise an error if not. + """ + if isinstance(image, str): + assert os.path.isfile(image) or image.startswith( + "http" + ), "Input image must be a valid URL or a file path." + elif isinstance(image, (torch.Tensor, PIL.Image.Image)): + raise ValueError( + "Input image must be a `torch.Tensor` or `PIL.Image.Image`." + ) + + def encode_image(self, image, device, num_meshes_per_prompt): + dtype = next(self.visual_encoder.parameters()).dtype + + image_embeds = self.visual_encoder.encode_image(image) + image_embeds = image_embeds.repeat_interleave(num_meshes_per_prompt, dim=0) + + uncond_image_embeds = self.visual_encoder.empty_image_embeds.repeat( + image_embeds.shape[0], 1, 1 + ).to(image_embeds) + + return image_embeds, uncond_image_embeds + + def encode_caption(self, caption, device, num_meshes_per_prompt): + dtype = next(self.label_encoder.parameters()).dtype + + caption_embeds = self.caption_encoder.encode_text([caption]) + caption_embeds = caption_embeds.repeat_interleave(num_meshes_per_prompt, dim=0) + + uncond_caption_embeds = self.caption_encoder.empty_text_embeds.repeat( + caption_embeds.shape[0], 1, 1 + ).to(caption_embeds) + + return caption_embeds, uncond_caption_embeds + + def encode_label(self, label, device, num_meshes_per_prompt): + dtype = next(self.label_encoder.parameters()).dtype + + label_embeds = self.label_encoder.encode_label([label]) + label_embeds = label_embeds.repeat_interleave(num_meshes_per_prompt, dim=0) + + uncond_label_embeds = self.label_encoder.empty_label_embeds.repeat( + label_embeds.shape[0], 1, 1 + ).to(label_embeds) + + return label_embeds, uncond_label_embeds + + def prepare_latents( + self, + batch_size, + num_tokens, + num_channels_latents, + dtype, + device, + generator, + latents: Optional[torch.Tensor] = None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = (batch_size, num_tokens, num_channels_latents) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[torch.FloatTensor, PIL.Image.Image, str], + label: Optional[str] = None, + caption: Optional[str] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + num_meshes_per_prompt: int = 1, + guidance_scale: float = 7.5, + generator: Optional[int] = None, + latents: Optional[torch.FloatTensor] = None, + force_remove_background: bool = False, + background_color: List[int] = [255, 255, 255], + foreground_ratio: float = 0.95, + surface_extractor_type: Optional[str] = None, + bounds: float = 1.05, + mc_level: float = 0.0, + octree_resolution: int = 384, + output_type: str = "trimesh", + do_remove_floater: bool = True, + do_remove_degenerate_face: bool = False, + do_reduce_face: bool = True, + do_shade_smooth: bool = True, + max_facenum: int = 200000, + return_dict: bool = True, + use_zero_init: Optional[bool] = True, + zero_steps: Optional[int] = 0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.FloatTensor` or `PIL.Image.Image` or `str`): + `Image`, or tensor representing an image batch, or path to an image file. The image will be encoded to + its CLIP/DINO-v2 embedding which the DiT will be conditioned on. + label (`str`): + The label of the generated mesh, like {"symmetry": "asymmetry", "edge_type": "smooth"} + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. More denoising steps usually lead to a higher quality mesh at the expense + of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not provided, will use equally spaced timesteps. + num_meshes_per_prompt (`int`, *optional*, defaults to 1): + The number of meshes to generate per input image. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + Higher guidance scale encourages generation that closely matches the input image. + generator (`int`, *optional*): + A seed to make the generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents to use as inputs for mesh generation. + force_remove_background (`bool`, *optional*, defaults to `False`): + Whether to force remove the background from the input image before processing. + background_color (`List[int]`, *optional*, defaults to `[255, 255, 255]`): + RGB color values for the background if it needs to be removed or modified. + foreground_ratio (`float`, *optional*, defaults to 0.95): + Ratio of the image to consider as foreground when processing. + surface_extractor_type (`str`, *optional*, defaults to "mc"): + Type of surface extraction method to use ("mc" for Marching Cubes or other available methods). + bounds (`float`, *optional*, defaults to 1.05): + Bounding box size for the generated mesh. + mc_level (`float`, *optional*, defaults to 0.0): + Iso-surface level value for Marching Cubes extraction. + octree_resolution (`int`, *optional*, defaults to 256): + Resolution of the octree used for mesh generation. + output_type (`str`, *optional*, defaults to "trimesh"): + Type of output mesh format ("trimesh" or other supported formats). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a `MeshPipelineOutput` instead of a plain tuple. + + Returns: + [`MeshPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`MeshPipelineOutput`] is returned, otherwise a `tuple` is returned where the + first element is a list of generated meshes and the second element is a list of corresponding metadata. + """ + # 0. Check inputs. Raise error if not correct + self.check_inputs( + image=image, + ) + device = self._execution_device + self._guidance_scale = guidance_scale + + # 1. Define call parameters + if isinstance(image, torch.Tensor): + batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image) or isinstance(image, str): + batch_size = 1 + + # 2. Preprocess input image + if isinstance(image, torch.Tensor): + assert image.ndim == 3 # H, W, 3 + image_pil = TF.to_pil_image(image) + elif isinstance(image, PIL.Image.Image): + image_pil = image + elif isinstance(image, str): + if image.startswith("http"): + import requests + + image_pil = PIL.Image.open(requests.get(image, stream=True).raw) + else: + image_pil = PIL.Image.open(image) + image_pil = preprocess_image(image_pil, force=force_remove_background, background_color=background_color, foreground_ratio=foreground_ratio) # remove the background images + + # 3. Encode condition + image_embeds, negative_image_embeds = self.encode_image( + image_pil, device, num_meshes_per_prompt + ) + if self.do_classifier_free_guidance and image_embeds is not None: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + # 3.1 Encode label condition + label_embeds = None + if self.transformer.cfg.use_label_condition: + if label is not None: + label_embeds, negative_label_embeds = self.encode_label( + label, device, num_meshes_per_prompt + ) + if self.do_classifier_free_guidance: + label_embeds = torch.cat( + [negative_label_embeds, label_embeds], dim=0 + ) + else: + uncond_label_embeds = self.label_encoder.empty_label_embeds.repeat( + num_meshes_per_prompt, 1, 1 + ).to(image_embeds) + if self.do_classifier_free_guidance: + label_embeds = torch.cat( + [uncond_label_embeds, uncond_label_embeds], dim=0 + ) + # 3.3 Encode caption condition + caption_embeds = None + if self.transformer.cfg.use_caption_condition: + if caption is not None: + caption_embeds, negative_caption_embeds = self.encode_caption( + caption, device, num_meshes_per_prompt + ) + if self.do_classifier_free_guidance: + caption_embeds = torch.cat( + [negative_caption_embeds, caption_embeds], dim=0 + ) + else: + uncond_caption_embeds = self.caption_encoder.empty_text_embeds.repeat( + num_meshes_per_prompt, 1, 1 + ).to(image_embeds) + if self.do_classifier_free_guidance: + caption_embeds = torch.cat( + [uncond_caption_embeds, uncond_caption_embeds], dim=0 + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps + ) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_latents = self.vae.cfg.num_latents + num_channels_latents = self.transformer.cfg.input_channels + latents = self.prepare_latents( + batch_size * num_meshes_per_prompt, + num_latents, + num_channels_latents, + image_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + latent_model_input, + timestep, + visual_condition=image_embeds, + label_condition=label_embeds, + caption_condition=caption_embeds, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_image = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_image - noise_pred_uncond + ) + + if (i <= zero_steps) and use_zero_init: + noise_pred = noise_pred * 0.0 + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False + )[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + # 4. Post-processing + if not output_type == "latent": + if latents.dtype == torch.bfloat16: + self.vae.to(torch.float16) + latents = latents.to(torch.float16) + mesh = self.vae.extract_geometry( + self.vae.decode(latents), + surface_extractor_type=surface_extractor_type, + bounds=bounds, + mc_level=mc_level, + octree_resolution=octree_resolution, + enable_pbar=False, + ) + if output_type != "raw": + mesh_list = [] + for i, cur_mesh in enumerate(mesh): + print(f"Generating mesh {i+1}/{num_meshes_per_prompt}") + if output_type == "trimesh": + import trimesh + + cur_mesh = trimesh.Trimesh( + vertices=cur_mesh.verts.cpu().numpy(), + faces=cur_mesh.faces.cpu().numpy(), + ) + cur_mesh.fix_normals() + cur_mesh.face_normals + cur_mesh.vertex_normals + cur_mesh.visual = trimesh.visual.TextureVisuals( + material=trimesh.visual.material.PBRMaterial( + baseColorFactor=(255, 255, 255), + main_color=(255, 255, 255), + metallicFactor=0.05, + roughnessFactor=1.0, + ) + ) + if do_remove_floater: + cur_mesh = remove_floater(cur_mesh) + if do_remove_degenerate_face: + cur_mesh = remove_degenerate_face(cur_mesh) + if do_reduce_face and max_facenum > 0: + cur_mesh = reduce_face(cur_mesh, max_facenum) + if do_shade_smooth: + cur_mesh = cur_mesh.smooth_shaded + mesh_list.append(cur_mesh) + elif output_type == "np": + if do_remove_floater: + print( + 'remove floater is NOT used when output_type is "np". ' + ) + if do_remove_degenerate_face: + print( + 'remove degenerate face is NOT used when output_type is "np". ' + ) + if do_reduce_face: + print( + 'reduce floater is NOT used when output_type is "np". ' + ) + if do_shade_smooth: + print('shade smooth is NOT used when output_type is "np". ') + mesh_list.append( + [ + cur_mesh[0].verts.cpu().numpy(), + cur_mesh[0].faces.cpu().numpy(), + ] + ) + mesh = mesh_list + else: + if do_remove_floater: + print('remove floater is NOT used when output_type is "raw". ') + if do_remove_degenerate_face: + print( + 'remove degenerate face is NOT used when output_type is "raw". ' + ) + if do_reduce_face: + print('reduce floater is NOT used when output_type is "raw". ') + + else: + mesh = latents + + if not return_dict: + return tuple(image_pil), tuple(mesh) + return Step1X3DGeometryPipelineOutput(image=image_pil, mesh=mesh) diff --git a/step1x3d_geometry/models/pipelines/pipeline_utils.py b/step1x3d_geometry/models/pipelines/pipeline_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..020d7002b04400243ebd1e1206b7db5475e88f0d --- /dev/null +++ b/step1x3d_geometry/models/pipelines/pipeline_utils.py @@ -0,0 +1,404 @@ +from typing import Callable, List, Optional, Union, Dict, Any +import os +from diffusers.utils import logging +import PIL.Image +import torch +import trimesh +import pymeshlab +import tempfile +from step1x3d_geometry.models.autoencoders.surface_extractors import MeshExtractResult + +logger = logging.get_logger(__name__) + + +def preprocess_image( + images_pil: Union[List[PIL.Image.Image], PIL.Image.Image], + force: bool = False, + background_color: List[int] = [255, 255, 255], + foreground_ratio: float = 0.9, + rembg_backend: str = "bria", + **rembg_kwargs, +): + r""" + Crop and remote the background of the input image + Args: + image_pil (`List[PIL.Image.Image]`): + List of `PIL.Image.Image` objects representing the input image. + force (`bool`, *optional*, defaults to `False`): + Whether to force remove the background even if the image has an alpha channel. + Returns: + `List[PIL.Image.Image]`: List of `PIL.Image.Image` objects representing the preprocessed image. + """ + is_single_image = False + if isinstance(images_pil, PIL.Image.Image): + images_pil = [images_pil] + is_single_image = True + preprocessed_images = [] + for i in range(len(images_pil)): + image = images_pil[i] + width, height, size = image.width, image.height, image.size + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + # explain why current do not rm bg + print( + "alhpa channl not empty, skip remove background, using alpha channel as mask" + ) + do_remove = False + do_remove = do_remove or force + if do_remove: + import rembg # lazy import + + if rembg_backend == "default": + image = rembg.remove(image, **rembg_kwargs) + else: + image = rembg.remove( + image, + session=rembg.new_session( + model_name="bria", + providers=[ + ( + "CUDAExecutionProvider", + { + "device_id": 0, + "arena_extend_strategy": "kSameAsRequested", + "gpu_mem_limit": 6 * 1024 * 1024 * 1024, + "cudnn_conv_algo_search": "HEURISTIC", + }, + ), + "CPUExecutionProvider", + ], + ), + **rembg_kwargs, + ) + + # calculate the min bbox of the image + alpha = image.split()[-1] + bboxs = alpha.getbbox() + x1, y1, x2, y2 = bboxs + dy, dx = y2 - y1, x2 - x1 + s = min(height * foreground_ratio / dy, width * foreground_ratio / dx) + Ht, Wt = int(dy * s), int(dx * s) + + background = PIL.Image.new("RGBA", image.size, (*background_color, 255)) + image = PIL.Image.alpha_composite(background, image) + image = image.crop(alpha.getbbox()) + alpha = alpha.crop(alpha.getbbox()) + + # Calculate the new size after rescaling + new_size = tuple(int(dim * foreground_ratio) for dim in size) + # Resize the image while maintaining the aspect ratio + resized_image = image.resize((Wt, Ht)) + resized_alpha = alpha.resize((Wt, Ht)) + # Create a new image with the original size and white background + padded_image = PIL.Image.new("RGB", size, tuple(background_color)) + padded_alpha = PIL.Image.new("L", size, (0)) + paste_position = ( + (width - resized_image.width) // 2, + (height - resized_image.height) // 2, + ) + padded_image.paste(resized_image, paste_position) + padded_alpha.paste(resized_alpha, paste_position) + + # expand image to 1:1 + width, height = padded_image.size + if width == height: + padded_image.putalpha(padded_alpha) + preprocessed_images.append(padded_image) + continue + new_size = (max(width, height), max(width, height)) + new_image = PIL.Image.new("RGB", new_size, tuple(background_color)) + new_alpha = PIL.Image.new("L", new_size, (0)) + paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) + new_image.paste(padded_image, paste_position) + new_alpha.paste(padded_alpha, paste_position) + new_image.putalpha(new_alpha) + preprocessed_images.append(new_image) + + if is_single_image: + return preprocessed_images[0] + return preprocessed_images + + +def load_mesh(path): + if path.endswith(".glb"): + mesh = trimesh.load(path) + else: + mesh = pymeshlab.MeshSet() + mesh.load_new_mesh(path) + return mesh + + +def trimesh2pymeshlab(mesh: trimesh.Trimesh): + with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file: + if isinstance(mesh, trimesh.scene.Scene): + for idx, obj in enumerate(mesh.geometry.values()): + if idx == 0: + temp_mesh = obj + else: + temp_mesh = temp_mesh + obj + mesh = temp_mesh + mesh.export(temp_file.name) + mesh = pymeshlab.MeshSet() + mesh.load_new_mesh(temp_file.name) + return mesh + + +def pymeshlab2trimesh(mesh: pymeshlab.MeshSet): + with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file: + mesh.save_current_mesh(temp_file.name) + mesh = trimesh.load(temp_file.name) + if isinstance(mesh, trimesh.Scene): + combined_mesh = trimesh.Trimesh() + for geom in mesh.geometry.values(): + combined_mesh = trimesh.util.concatenate([combined_mesh, geom]) + mesh = combined_mesh + return mesh + + +def import_mesh(mesh): + mesh_type = type(mesh) + if isinstance(mesh, str): + mesh = load_mesh(mesh) + elif isinstance(mesh, MeshExtractResult): + mesh = pymeshlab.MeshSet() + mesh_pymeshlab = pymeshlab.Mesh( + vertex_matrix=mesh.verts.cpu().numpy(), face_matrix=mesh.faces.cpu().numpy() + ) + mesh.add_mesh(mesh_pymeshlab, "converted_mesh") + + if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)): + mesh = trimesh2pymeshlab(mesh) + + return mesh, mesh_type + + +def remove_floater(mesh): + mesh, mesh_type = import_mesh(mesh) + + mesh.apply_filter( + "compute_selection_by_small_disconnected_components_per_face", nbfaceratio=0.001 + ) + mesh.apply_filter("compute_selection_transfer_face_to_vertex", inclusive=False) + mesh.apply_filter("meshing_remove_selected_vertices_and_faces") + + return pymeshlab2trimesh(mesh) + + +def remove_degenerate_face(mesh): + mesh, mesh_type = import_mesh(mesh) + + with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file: + mesh.save_current_mesh(temp_file.name) + mesh = pymeshlab.MeshSet() + mesh.load_new_mesh(temp_file.name) + + return pymeshlab2trimesh(mesh) + + +def reduce_face(mesh, max_facenum=50000): + mesh, mesh_type = import_mesh(mesh) + + if max_facenum > mesh.current_mesh().face_number(): + return pymeshlab2trimesh(mesh) + + mesh.apply_filter( + "meshing_decimation_quadric_edge_collapse", + targetfacenum=max_facenum, + qualitythr=1.0, + preserveboundary=True, + boundaryweight=3, + preservenormal=True, + preservetopology=True, + autoclean=True, + ) + + return pymeshlab2trimesh(mesh) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class TransformerDiffusionMixin: + r""" + Helper for DiffusionPipeline with vae and transformer.(mainly for DIT) + """ + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + transformer (`bool`, defaults to `True`): To apply fusion on the Transformer. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + """ + self.fusing_transformer = False + self.fusing_vae = False + + if transformer: + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + if vae: + self.fusing_vae = True + self.vae.fuse_qkv_projections() + + def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + transformer (`bool`, defaults to `True`): To apply fusion on the Transformer. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if transformer: + if not self.fusing_transformer: + logger.warning( + "The UNet was not initially fused for QKV projections. Doing nothing." + ) + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + if vae: + if not self.fusing_vae: + logger.warning( + "The VAE was not initially fused for QKV projections. Doing nothing." + ) + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False + +def try_download(model_id, subfolder): + try: + from huggingface_hub import snapshot_download + + path = snapshot_download( + repo_id=model_id, + allow_patterns=[f"{subfolder}/*"], + ) + print(path) + model_path = os.path.join(path, subfolder) + return model_path + except Exception as e: + raise e + + +def smart_load_model(model_path, subfolder = ""): + if subfolder == "": + if os.path.exists(model_path): + return model_path + else: + return try_download(model_path, '.') + else: + if os.path.exists(os.path.join(model_path, subfolder)): + return os.path.join(model_path, subfolder) + else: + return try_download(model_path, subfolder) + + + diff --git a/step1x3d_geometry/models/transformers/__init__.py b/step1x3d_geometry/models/transformers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..63dfb03e5239b67b3160c2beea123faf1d44a7dc --- /dev/null +++ b/step1x3d_geometry/models/transformers/__init__.py @@ -0,0 +1 @@ +from . import flux_transformer_1d, pixart_transformer_1d diff --git a/step1x3d_geometry/models/transformers/flux_transformer_1d.py b/step1x3d_geometry/models/transformers/flux_transformer_1d.py new file mode 100755 index 0000000000000000000000000000000000000000..642949388cca866fb26f6f439e461c5250e4d4ca --- /dev/null +++ b/step1x3d_geometry/models/transformers/flux_transformer_1d.py @@ -0,0 +1,600 @@ +# Some parts of this file are adapted from Hugging Face Diffusers library. +from typing import Any, Dict, Optional, Union, Tuple +from dataclasses import dataclass + +import re +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.models.attention_processor import ( + Attention, + AttentionProcessor, + AttnProcessor, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import ( + GaussianFourierProjection, + TimestepEmbedding, + Timesteps, +) +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.models.normalization import ( + AdaLayerNormSingle, + AdaLayerNormContinuous, + FP32LayerNorm, + LayerNorm, +) + +from ..attention_processor import FusedFluxAttnProcessor2_0, FluxAttnProcessor2_0 +from ..attention import FluxTransformerBlock, FluxSingleTransformerBlock + +import step1x3d_geometry +from step1x3d_geometry.utils.base import BaseModule + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Transformer1DModelOutput: + sample: torch.FloatTensor + + +class FluxTransformer1DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + r""" + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-la + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + width (`int`, *optional*, defaults to 2048): + Maximum sequence length in latent space (equivalent to max_seq_length in Transformers). + Determines the first dimension size of positional embedding matrices[1](@ref). + in_channels (`int`, *optional*, defaults to 64): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): + Dimensionality of conditional embeddings for cross-attention mechanisms + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + width: int = 2048, + in_channels: int = 4, + num_layers: int = 19, + num_single_layers: int = 38, + cross_attention_dim: int = 768, + ): + super().__init__() + # Set some common variables used across the board. + self.out_channels = in_channels + self.num_heads = num_attention_heads + self.inner_dim = width + + # self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + # self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) + time_embed_dim, timestep_input_dim = self._set_time_proj( + "positional", + inner_dim=self.inner_dim, + flip_sin_to_cos=False, + freq_shift=0, + time_embedding_dim=None, + ) + self.time_proj = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim + ) + self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True) + self.proj_cross_attention = nn.Linear( + self.config.cross_attention_dim, self.inner_dim, bias=True + ) + + # 2. Initialize the transformer blocks. + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=width // num_attention_heads, + ) + for _ in range(self.config.num_layers) + ] + ) + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=width // num_attention_heads, + ) + for _ in range(self.config.num_single_layers) + ] + ) + + # 3. Output blocks. + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 + ) + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def _set_time_proj( + self, + time_embedding_type: str, + inner_dim: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or inner_dim * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." + ) + self.time_embed = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or inner_dim * 4 + + self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + timestep_input_dim = inner_dim + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError( + "`fuse_qkv_projections()` is not supported for models having added KV projections." + ) + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedFluxAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(FluxAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking( + self, chunk_size: Optional[int] = None, dim: int = 0 + ) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward( + module: torch.nn.Module, chunk_size: int, dim: int + ): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward( + module: torch.nn.Module, chunk_size: int, dim: int + ): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + def forward( + self, + hidden_states: Optional[torch.Tensor], + timestep: Union[int, float, torch.LongTensor], + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + """ + The [`HunyuanDiT2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch size, dim, latents_size)`): + The input tensor. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. + encoder_hidden_states_2 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. + return_dict: bool + Whether to return a dictionary. + """ + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if ( + attention_kwargs is not None + and attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + _, N, _ = hidden_states.shape + + # import pdb; pdb.set_trace() + # timesteps_proj = self.time_proj(timestep) # N x 256 + # temb = self.time_embed(timesteps_proj).to(hidden_states.dtype) + temb = self.time_embed(timestep).to(hidden_states.dtype) # N x 1280 + temb = self.time_proj(temb) # N x 1280 + + hidden_states = self.proj_in(hidden_states) + encoder_hidden_states = self.proj_cross_attention(encoder_hidden_states) + + for layer, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + encoder_hidden_states, hidden_states = ( + torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + None, # image_rotary_emb + attention_kwargs, + ) + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=None, + joint_attention_kwargs=attention_kwargs, + ) # (N, L, D) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for layer, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + None, # image_rotary_emb + attention_kwargs, + ) + else: + hidden_states = block( + hidden_states, + temb=temb, + image_rotary_emb=None, + joint_attention_kwargs=attention_kwargs, + ) # (N, L, D) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + # final layer + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + + return Transformer1DModelOutput(sample=hidden_states) + + +@step1x3d_geometry.register("flux-denoiser") +class FluxDenoiser(BaseModule): + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: Optional[str] = None + input_channels: int = 32 + width: int = 768 + layers: int = 12 + num_single_layers: int = 12 + num_heads: int = 16 + condition_dim: int = 1024 + multi_condition_type: str = "in_context" + use_visual_condition: bool = False + visual_condition_dim: int = 1024 + n_views: int = 1 + use_caption_condition: bool = False + caption_condition_dim: int = 1024 + use_label_condition: bool = False + label_condition_dim: int = 1024 + + identity_init: bool = False + + cfg: Config + + def configure(self) -> None: + assert ( + self.cfg.multi_condition_type == "in_context" + ), "Flux Denoiser only support in_context learning of multiple conditions" + self.dit_model = FluxTransformer1DModel( + num_attention_heads=self.cfg.num_heads, + width=self.cfg.width, + in_channels=self.cfg.input_channels, + num_layers=self.cfg.layers, + num_single_layers=self.cfg.num_single_layers, + cross_attention_dim=self.cfg.condition_dim, + ) + if ( + self.cfg.use_visual_condition + and self.cfg.visual_condition_dim != self.cfg.condition_dim + ): + self.proj_visual_condtion = nn.Sequential( + nn.RMSNorm(self.cfg.visual_condition_dim), + nn.Linear(self.cfg.visual_condition_dim, self.cfg.condition_dim), + ) + if ( + self.cfg.use_caption_condition + and self.cfg.caption_condition_dim != self.cfg.condition_dim + ): + self.proj_caption_condtion = nn.Sequential( + nn.RMSNorm(self.cfg.caption_condition_dim), + nn.Linear(self.cfg.caption_condition_dim, self.cfg.condition_dim), + ) + if ( + self.cfg.use_label_condition + and self.cfg.label_condition_dim != self.cfg.condition_dim + ): + self.proj_label_condtion = nn.Sequential( + nn.RMSNorm(self.cfg.label_condition_dim), + nn.Linear(self.cfg.label_condition_dim, self.cfg.condition_dim), + ) + + if self.cfg.identity_init: + self.identity_initialize() + + if self.cfg.pretrained_model_name_or_path: + print( + f"Loading pretrained DiT model from {self.cfg.pretrained_model_name_or_path}" + ) + ckpt = torch.load( + self.cfg.pretrained_model_name_or_path, + map_location="cpu", + weights_only=True, + ) + if "state_dict" in ckpt.keys(): + ckpt = ckpt["state_dict"] + + self.load_state_dict(ckpt, strict=True) + + def identity_initialize(self): + for block in self.dit_model.blocks: + nn.init.constant_(block.attn.c_proj.weight, 0) + nn.init.constant_(block.attn.c_proj.bias, 0) + nn.init.constant_(block.cross_attn.c_proj.weight, 0) + nn.init.constant_(block.cross_attn.c_proj.bias, 0) + nn.init.constant_(block.mlp.c_proj.weight, 0) + nn.init.constant_(block.mlp.c_proj.bias, 0) + + def forward( + self, + model_input: torch.FloatTensor, + timestep: torch.LongTensor, + visual_condition: Optional[torch.FloatTensor] = None, + caption_condition: Optional[torch.FloatTensor] = None, + label_condition: Optional[torch.FloatTensor] = None, + attention_kwargs: Dict[str, torch.Tensor] = None, + return_dict: bool = True, + ): + r""" + Args: + model_input (torch.FloatTensor): [bs, n_data, c] + timestep (torch.LongTensor): [bs,] + visual_condition (torch.FloatTensor): [bs, visual_context_tokens, c] + caption_condition (torch.FloatTensor): [bs, text_context_tokens, c] + label_condition (torch.FloatTensor): [bs, c] + + Returns: + sample (torch.FloatTensor): [bs, n_data, c] + + """ + + B, n_data, _ = model_input.shape + + # 0. conditions projector + condition = [] + if self.cfg.use_visual_condition: + assert visual_condition.shape[-1] == self.cfg.visual_condition_dim + if self.cfg.visual_condition_dim != self.cfg.condition_dim: + visual_condition = self.proj_visual_condtion(visual_condition) + condition.append(visual_condition) + if self.cfg.use_caption_condition: + assert caption_condition.shape[-1] == self.cfg.caption_condition_dim + if self.cfg.caption_condition_dim != self.cfg.condition_dim: + caption_condition = self.proj_caption_condtion(caption_condition) + condition.append(caption_condition) + if self.cfg.use_label_condition: + assert label_condition.shape[-1] == self.cfg.label_condition_dim + if self.cfg.label_condition_dim != self.cfg.condition_dim: + label_condition = self.proj_label_condtion(label_condition) + condition.append(label_condition) + + # 1. denoise + output = self.dit_model( + model_input, + timestep, + torch.cat(condition, dim=1), + attention_kwargs, + return_dict=return_dict, + ) + + return output diff --git a/step1x3d_geometry/models/transformers/pixart_transformer_1d.py b/step1x3d_geometry/models/transformers/pixart_transformer_1d.py new file mode 100755 index 0000000000000000000000000000000000000000..0dfccbf757b6f504a3118fc76ee16281c7fea688 --- /dev/null +++ b/step1x3d_geometry/models/transformers/pixart_transformer_1d.py @@ -0,0 +1,574 @@ +# Some parts of this file are adapted from Hugging Face Diffusers library. +from dataclasses import dataclass + +import re +import math +import torch +from torch import nn +from typing import Callable, List, Optional, Union, Dict, Any +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import logging +from diffusers.models.attention_processor import ( + Attention, + AttentionProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle + +from ..attention_processor import FusedAttnProcessor2_0, AttnProcessor2_0 +from ..attention import MultiCondBasicTransformerBlock + +import step1x3d_geometry +from step1x3d_geometry.utils.base import BaseModule + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Transformer1DModelOutput: + sample: torch.FloatTensor + + +class PixArtTransformer1DModel(ModelMixin, ConfigMixin): + r""" + A 1D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426, + https://arxiv.org/abs/2403.04692). + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + width (`int`, *optional*, defaults to 2048): + Maximum sequence length in latent space (equivalent to max_seq_length in Transformers). + Determines the first dimension size of positional embedding matrices[1](@ref). + in_channels (`int`, *optional*, defaults to 64): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): + Dimensionality of conditional embeddings for cross-attention mechanisms + use_cross_attention_2 (`bool`, *optional*): + Flag to enable secondary cross-attention mechanism. Used for multi-modal conditioning + when processing hybrid inputs (e.g., text + image prompts)[1](@ref). + cross_attention_2_dim (`int`, *optional*, defaults to 1024): + Dimensionality of secondary cross-attention embeddings. Specifies encoding dimensions + for additional conditional modalities when use_cross_attention_2 is enabled[1](@ref). + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["MultiCondBasicTransformerBlock", "PatchEmbed"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + width: int = 2048, + in_channels: int = 4, + num_layers: int = 28, + cross_attention_dim: int = 768, + use_cross_attention_2: bool = True, + cross_attention_2_dim: int = 1024, + use_cross_attention_3: bool = True, + cross_attention_3_dim: int = 1024, + ): + super().__init__() + # Set some common variables used across the board. + self.out_channels = in_channels + self.num_heads = num_attention_heads + self.inner_dim = width + + self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True) + + # 2. Initialize the transformer blocks. + self.transformer_blocks = nn.ModuleList( + [ + MultiCondBasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + use_self_attention=True, + use_cross_attention=True, + self_attention_norm_type="ada_norm_single", + cross_attention_dim=self.config.cross_attention_dim, + cross_attention_norm_type="ada_norm_single", + use_cross_attention_2=self.config.use_cross_attention_2, + cross_attention_2_dim=self.config.cross_attention_2_dim, + cross_attention_2_norm_type="ada_norm_single", + use_cross_attention_3=self.config.use_cross_attention_3, + cross_attention_3_dim=self.config.cross_attention_3_dim, + cross_attention_3_norm_type="ada_norm_single", + dropout=0.0, + attention_bias=False, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_elementwise_affine=True, + upcast_attention=False, + norm_eps=1e-6, + attention_type="default", + ) + for _ in range(self.config.num_layers) + ] + ) + + # 3. Output blocks. + self.norm_out = nn.RMSNorm(self.inner_dim, elementwise_affine=True, eps=1e-6) + self.scale_shift_table = nn.Parameter( + torch.randn(2, self.inner_dim) / self.inner_dim**0.5 + ) + self.proj_out = nn.Linear(self.inner_dim, self.out_channels) + + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=None + ) + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError( + "`fuse_qkv_projections()` is not supported for models having added KV projections." + ) + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + encoder_hidden_states_3: Optional[torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask_2: Optional[torch.Tensor] = None, + encoder_attention_mask_3: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`PixArtTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, n_tokens)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + encoder_hidden_states_2 (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + encoder_hidden_states_3 (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep (`torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + encoder_attention_mask_2 ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states_2`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + encoder_attention_mask_3 ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states_3`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~Transformer1DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # convert encoder_attention_mask_2 to a bias the same way we do for attention_mask + if encoder_attention_mask_2 is not None and encoder_attention_mask_2.ndim == 2: + encoder_attention_mask_2 = ( + 1 - encoder_attention_mask_2.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask_2 = encoder_attention_mask_2.unsqueeze(1) + + # convert encoder_attention_mask_2 to a bias the same way we do for attention_mask + if encoder_attention_mask_3 is not None and encoder_attention_mask_3.ndim == 2: + encoder_attention_mask_3 = ( + 1 - encoder_attention_mask_3.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask_3 = encoder_attention_mask_3.unsqueeze(1) + + # 1. Input + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_hidden_states_2, + encoder_hidden_states_3, + encoder_attention_mask, + encoder_attention_mask_2, + encoder_attention_mask_3, + timestep, + cross_attention_kwargs, + None, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_2=encoder_hidden_states_2, + encoder_hidden_states_3=encoder_hidden_states_3, + encoder_attention_mask=encoder_attention_mask, + encoder_attention_mask_2=encoder_attention_mask_2, + encoder_attention_mask_3=encoder_attention_mask_3, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=None, + ) + + # 3. Output + shift, scale = ( + self.scale_shift_table[None] + + embedded_timestep[:, None].to(self.scale_shift_table.device) + ).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to( + hidden_states.device + ) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + if not return_dict: + return (hidden_states,) + + return Transformer1DModelOutput(sample=hidden_states) + + +@step1x3d_geometry.register("pixart-denoiser") +class PixArtDenoiser(BaseModule): + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: Optional[str] = None + input_channels: int = 32 + width: int = 768 + layers: int = 28 + num_heads: int = 16 + condition_dim: int = 1024 + multi_condition_type: str = "cross_attention" + use_visual_condition: bool = False + visual_condition_dim: int = 1024 + n_views: int = 1 # for multi-view condition + use_caption_condition: bool = False + caption_condition_dim: int = 1024 + use_label_condition: bool = False + label_condition_dim: int = 1024 + + identity_init: bool = False + + cfg: Config + + def configure(self) -> None: + self.dit_model = PixArtTransformer1DModel( + num_attention_heads=self.cfg.num_heads, + width=self.cfg.width, + in_channels=self.cfg.input_channels, + num_layers=self.cfg.layers, + cross_attention_dim=self.cfg.condition_dim, + use_cross_attention_2=self.cfg.use_caption_condition + and self.cfg.multi_condition_type == "cross_attention", + cross_attention_2_dim=self.cfg.condition_dim, + use_cross_attention_3=self.cfg.use_label_condition + and self.cfg.multi_condition_type == "cross_attention", + cross_attention_3_dim=self.cfg.condition_dim, + ) + if ( + self.cfg.use_visual_condition + and self.cfg.visual_condition_dim != self.cfg.condition_dim + ): + self.proj_visual_condtion = nn.Sequential( + nn.RMSNorm(self.cfg.visual_condition_dim), + nn.Linear(self.cfg.visual_condition_dim, self.cfg.condition_dim), + ) + if ( + self.cfg.use_caption_condition + and self.cfg.caption_condition_dim != self.cfg.condition_dim + ): + self.proj_caption_condtion = nn.Sequential( + nn.RMSNorm(self.cfg.caption_condition_dim), + nn.Linear(self.cfg.caption_condition_dim, self.cfg.condition_dim), + ) + if ( + self.cfg.use_label_condition + and self.cfg.label_condition_dim != self.cfg.condition_dim + ): + self.proj_label_condtion = nn.Sequential( + nn.RMSNorm(self.cfg.label_condition_dim), + nn.Linear(self.cfg.label_condition_dim, self.cfg.condition_dim), + ) + + if self.cfg.identity_init: + self.identity_initialize() + + if self.cfg.pretrained_model_name_or_path: + print( + f"Loading pretrained DiT model from {self.cfg.pretrained_model_name_or_path}" + ) + ckpt = torch.load( + self.cfg.pretrained_model_name_or_path, + map_location="cpu", + weights_only=False, + ) + if "state_dict" in ckpt.keys(): + ckpt = ckpt["state_dict"] + self.load_state_dict(ckpt, strict=True) + + def identity_initialize(self): + for block in self.dit_model.blocks: + nn.init.constant_(block.attn.c_proj.weight, 0) + nn.init.constant_(block.attn.c_proj.bias, 0) + nn.init.constant_(block.cross_attn.c_proj.weight, 0) + nn.init.constant_(block.cross_attn.c_proj.bias, 0) + nn.init.constant_(block.mlp.c_proj.weight, 0) + nn.init.constant_(block.mlp.c_proj.bias, 0) + + def forward( + self, + model_input: torch.FloatTensor, + timestep: torch.LongTensor, + visual_condition: Optional[torch.FloatTensor] = None, + caption_condition: Optional[torch.FloatTensor] = None, + label_condition: Optional[torch.FloatTensor] = None, + attention_kwargs: Dict[str, torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + return_dict: bool = True, + ): + r""" + Args: + model_input (torch.FloatTensor): [bs, n_data, c] + timestep (torch.LongTensor): [bs,] + visual_condition (torch.FloatTensor): [bs, visual_context_tokens, c] + text_condition (torch.FloatTensor): [bs, text_context_tokens, c] + + Returns: + sample (torch.FloatTensor): [bs, n_data, c] + + """ + + B, n_data, _ = model_input.shape + + # 0. conditions projector + condition = [] + if self.cfg.use_visual_condition: + assert visual_condition.shape[-1] == self.cfg.visual_condition_dim + if self.cfg.visual_condition_dim != self.cfg.condition_dim: + visual_condition = self.proj_visual_condtion(visual_condition) + condition.append(visual_condition) + else: + visual_condition = None + if self.cfg.use_caption_condition: + assert caption_condition.shape[-1] == self.cfg.caption_condition_dim + if self.cfg.caption_condition_dim != self.cfg.condition_dim: + caption_condition = self.proj_caption_condtion(caption_condition) + condition.append(caption_condition) + else: + caption_condition = None + if self.cfg.use_label_condition: + assert label_condition.shape[-1] == self.cfg.label_condition_dim + if self.cfg.label_condition_dim != self.cfg.condition_dim: + label_condition = self.proj_label_condtion(label_condition) + condition.append(label_condition) + else: + label_condition = None + assert not ( + visual_condition is None + and caption_condition is None + and label_condition is None + ) + + # 1. denoise + if self.cfg.multi_condition_type == "cross_attention": + output = self.dit_model( + model_input, + timestep, + visual_condition, + caption_condition, + label_condition, + cross_attention_kwargs, + return_dict=return_dict, + ) + elif self.cfg.multi_condition_type == "in_context": + output = self.dit_model( + model_input, + timestep, + torch.cat(condition, dim=1), + None, + None, + cross_attention_kwargs, + return_dict=return_dict, + ) + else: + raise ValueError + + return output diff --git a/step1x3d_geometry/systems/__init__.py b/step1x3d_geometry/systems/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..62bfba8d0c24e177deffb43fea4f814545340065 --- /dev/null +++ b/step1x3d_geometry/systems/__init__.py @@ -0,0 +1 @@ +from . import shape_autoencoder, shape_diffusion, shape_rectified_flow diff --git a/step1x3d_geometry/systems/base.py b/step1x3d_geometry/systems/base.py new file mode 100755 index 0000000000000000000000000000000000000000..afaa28a5d1e25f429fc039994340113ea4afb6a1 --- /dev/null +++ b/step1x3d_geometry/systems/base.py @@ -0,0 +1,210 @@ +import os +from dataclasses import dataclass, field + +import pytorch_lightning as pl +import torch.nn.functional as F + +import step1x3d_geometry +from step1x3d_geometry.utils.base import ( + Updateable, + update_end_if_possible, + update_if_possible, +) +from step1x3d_geometry.utils.scheduler import parse_optimizer, parse_scheduler +from step1x3d_geometry.utils.config import parse_structured +from step1x3d_geometry.utils.misc import C, cleanup, get_device, load_module_weights +from step1x3d_geometry.utils.saving import SaverMixin +from step1x3d_geometry.utils.typing import * + + +class BaseSystem(pl.LightningModule, Updateable, SaverMixin): + @dataclass + class Config: + loggers: dict = field(default_factory=dict) + loss: dict = field(default_factory=dict) + optimizer: dict = field(default_factory=dict) + scheduler: Optional[dict] = None + weights: Optional[str] = None + weights_ignore_modules: Optional[List[str]] = None + cleanup_after_validation_step: bool = False + cleanup_after_test_step: bool = False + + pretrained_model_path: Optional[str] = None + strict_load: bool = True + + cfg: Config + + def __init__(self, cfg, resumed=False) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self._save_dir: Optional[str] = None + self._resumed: bool = resumed + self._resumed_eval: bool = False + self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0} + if "loggers" in cfg: + self.create_loggers(cfg.loggers) + + self.configure() + if self.cfg.weights is not None: + self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules) + self.post_configure() + + def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None): + state_dict, epoch, global_step = load_module_weights( + weights, ignore_modules=ignore_modules, map_location="cpu" + ) + self.load_state_dict(state_dict, strict=False) + # restore step-dependent states + self.do_update_step(epoch, global_step, on_load_weights=True) + + def set_resume_status(self, current_epoch: int, global_step: int): + # restore correct epoch and global step in eval + self._resumed_eval = True + self._resumed_eval_status["current_epoch"] = current_epoch + self._resumed_eval_status["global_step"] = global_step + + @property + def resumed(self): + # whether from resumed checkpoint + return self._resumed + + @property + def true_global_step(self): + if self._resumed_eval: + return self._resumed_eval_status["global_step"] + else: + return self.global_step + + @property + def true_current_epoch(self): + if self._resumed_eval: + return self._resumed_eval_status["current_epoch"] + else: + return self.current_epoch + + def configure(self) -> None: + pass + + def post_configure(self) -> None: + """ + executed after weights are loaded + """ + pass + + def C(self, value: Any) -> float: + return C(value, self.true_current_epoch, self.true_global_step) + + def configure_optimizers(self): + optim = parse_optimizer(self.cfg.optimizer, self) + ret = { + "optimizer": optim, + } + if self.cfg.scheduler is not None: + ret.update( + { + "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim), + } + ) + return ret + + def training_step(self, batch, batch_idx): + raise NotImplementedError + + def validation_step(self, batch, batch_idx): + raise NotImplementedError + + def on_train_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.train_dataloader.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + + def on_validation_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.val_dataloaders.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + if self.cfg.cleanup_after_validation_step: + # cleanup to save vram + cleanup() + + def on_validation_epoch_end(self): + raise NotImplementedError + + def test_step(self, batch, batch_idx): + raise NotImplementedError + + def on_test_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.test_dataloaders.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + if self.cfg.cleanup_after_test_step: + # cleanup to save vram + cleanup() + + def on_test_epoch_end(self): + pass + + def predict_step(self, batch, batch_idx): + raise NotImplementedError + + def on_predict_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.predict_dataloaders.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + if self.cfg.cleanup_after_test_step: + # cleanup to save vram + cleanup() + + def on_predict_epoch_end(self): + pass + + def preprocess_data(self, batch, stage): + pass + + """ + Implementing on_after_batch_transfer of DataModule does the same. + But on_after_batch_transfer does not support DP. + """ + + def on_train_batch_start(self, batch, batch_idx, unused=0): + self.preprocess_data(batch, "train") + self.dataset = self.trainer.train_dataloader.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0): + self.preprocess_data(batch, "validation") + self.dataset = self.trainer.val_dataloaders.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0): + self.preprocess_data(batch, "test") + self.dataset = self.trainer.test_dataloaders.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0): + self.preprocess_data(batch, "predict") + self.dataset = self.trainer.predict_dataloaders.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + pass + + def on_before_optimizer_step(self, optimizer): + """ + # some gradient-related debugging goes here, example: + from lightning.pytorch.utilities import grad_norm + norms = grad_norm(self.geometry, norm_type=2) + print(norms) + """ + pass diff --git a/step1x3d_geometry/systems/shape_autoencoder.py b/step1x3d_geometry/systems/shape_autoencoder.py new file mode 100755 index 0000000000000000000000000000000000000000..84fbd721c05f3978571ab7f8489c6bb387b3179a --- /dev/null +++ b/step1x3d_geometry/systems/shape_autoencoder.py @@ -0,0 +1,151 @@ +from dataclasses import dataclass, field +import numpy as np +import torch +from skimage import measure +from einops import repeat, rearrange + +import step1x3d_geometry +from step1x3d_geometry.systems.base import BaseSystem +from step1x3d_geometry.utils.ops import generate_dense_grid_points +from step1x3d_geometry.utils.typing import * +from step1x3d_geometry.utils.misc import get_rank + + +@step1x3d_geometry.register("shape-autoencoder-system") +class ShapeAutoEncoderSystem(BaseSystem): + @dataclass + class Config(BaseSystem.Config): + shape_model_type: str = None + shape_model: dict = field(default_factory=dict) + + sample_posterior: bool = True + + # for mesh extraction + bounds: float = 1.05 + mc_level: float = 0.0 + octree_resolution: int = 256 + + cfg: Config + + def configure(self): + super().configure() + + self.shape_model = step1x3d_geometry.find(self.cfg.shape_model_type)( + self.cfg.shape_model + ) + + def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: + rand_points = batch["rand_points"] + if "sdf" in batch: + target = batch["sdf"] + criteria = torch.nn.MSELoss() + elif "occupancies" in batch: + target = batch["occupancies"] + criteria = torch.nn.BCEWithLogitsLoss() + else: + raise NotImplementedError + + # forward pass + num_point_feats = 3 + self.cfg.shape_model.point_feats + shape_latents, kl_embed, posterior = self.shape_model.encode( + batch["surface"][..., :num_point_feats], + sharp_surface=( + batch["sharp_surface"][..., :num_point_feats] + if "sharp_surface" in batch + else None + ), + sample_posterior=self.cfg.sample_posterior, + ) + latents = self.shape_model.decode(kl_embed) # [B, num_latents, width] + logits = self.shape_model.query(rand_points, latents).squeeze( + -1 + ) # [B, num_rand_points] + + if self.cfg.sample_posterior: + loss_kl = posterior.kl() + loss_kl = torch.sum(loss_kl) / loss_kl.shape[0] + + return { + "loss_logits": criteria(logits, target).mean(), + "loss_kl": loss_kl, + "logits": logits, + "target": target, + "latents": latents, + } + else: + return { + "loss_logits": criteria(logits, target).mean(), + "latents": latents, + "logits": logits, + } + + def training_step(self, batch, batch_idx): + """ + Description: + + Args: + batch: + batch_idx: + Returns: + loss: + """ + out = self(batch) + + loss = 0.0 + for name, value in out.items(): + if name.startswith("loss_"): + self.log(f"train/{name}", value) + loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) + + for name, value in self.cfg.loss.items(): + self.log(f"train_params/{name}", self.C(value)) + + return {"loss": loss} + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + self.eval() + out = self(batch) + + meshes = self.shape_model.extract_geometry( + out["latents"], + bounds=self.cfg.bounds, + mc_level=self.cfg.mc_level, + octree_resolution=self.cfg.octree_resolution, + enable_pbar=False, + ) + for idx, name in enumerate(batch["uid"]): + self.save_mesh( + f"it{self.true_global_step}/{name}.obj", + meshes[idx].verts, + meshes[idx].faces, + ) + + threshold = 0 + outputs = out["logits"] + labels = out["target"] + pred = torch.zeros_like(outputs) + pred[outputs >= threshold] = 1 + + accuracy = (pred == labels).float().sum(dim=1) / labels.shape[1] + accuracy = accuracy.mean() + intersection = (pred * labels).sum(dim=1) + union = (pred + labels).gt(0).sum(dim=1) + iou = intersection * 1.0 / union + 1e-5 + iou = iou.mean() + self.log("val/accuracy", accuracy) + self.log("val/iou", iou) + + torch.cuda.empty_cache() + + return { + "val/loss": out["loss_logits"], + "val/accuracy": accuracy, + "val/iou": iou, + } + + def on_validation_epoch_end(self): + pass + + def test_step(self, batch, batch_idx): + return diff --git a/step1x3d_geometry/systems/shape_diffusion.py b/step1x3d_geometry/systems/shape_diffusion.py new file mode 100755 index 0000000000000000000000000000000000000000..5fbc9e2952af287930954706003a1321a1ae151d --- /dev/null +++ b/step1x3d_geometry/systems/shape_diffusion.py @@ -0,0 +1,425 @@ +from dataclasses import dataclass, field + +from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline +import numpy as np +import json +import copy +import torch +import torch.nn.functional as F +from skimage import measure +from einops import repeat +from tqdm import tqdm +from PIL import Image + +from diffusers import ( + DDPMScheduler, + DDIMScheduler, + UniPCMultistepScheduler, + KarrasVeScheduler, + DPMSolverMultistepScheduler, +) +from diffusers.training_utils import ( + compute_snr, + free_memory, +) +import step1x3d_geometry +from step1x3d_geometry.systems.base import BaseSystem +from step1x3d_geometry.utils.misc import get_rank +from step1x3d_geometry.utils.typing import * +from diffusers import DDIMScheduler +from step1x3d_geometry.systems.utils import read_image, ddim_sample + + +# DEBUG = True +@step1x3d_geometry.register("diffusion-system") +class DiffusionSystem(BaseSystem): + @dataclass + class Config(BaseSystem.Config): + val_samples_json: str = "" + bounds: float = 1.05 + mc_level: float = 0.0 + octree_resolution: int = 256 + skip_validation: bool = True + + # diffusion config + z_scale_factor: float = 1.0 + guidance_scale: float = 7.5 + num_inference_steps: int = 50 + eta: float = 0.0 + snr_gamma: float = 5.0 + + # shape vae model + shape_model_type: str = None + shape_model: dict = field(default_factory=dict) + + # condition model + visual_condition_type: Optional[str] = None + visual_condition: dict = field(default_factory=dict) + caption_condition_type: Optional[str] = None + caption_condition: dict = field(default_factory=dict) + label_condition_type: Optional[str] = None + label_condition: dict = field(default_factory=dict) + + # diffusion model + denoiser_model_type: str = None + denoiser_model: dict = field(default_factory=dict) + + # noise scheduler + noise_scheduler_type: str = None + noise_scheduler: dict = field(default_factory=dict) + + # denoise scheduler + denoise_scheduler_type: str = None + denoise_scheduler: dict = field(default_factory=dict) + + cfg: Config + + def configure(self): + super().configure() + + self.shape_model = step1x3d_geometry.find(self.cfg.shape_model_type)( + self.cfg.shape_model + ) + self.shape_model.eval() + self.shape_model.requires_grad_(False) + + if self.cfg.visual_condition_type is not None: + self.visual_condition = step1x3d_geometry.find( + self.cfg.visual_condition_type + )(self.cfg.visual_condition) + + if self.cfg.caption_condition_type is not None: + self.caption_condition = step1x3d_geometry.find( + self.cfg.caption_condition_type + )(self.cfg.caption_condition) + + if self.cfg.label_condition_type is not None: + self.label_condition = step1x3d_geometry.find( + self.cfg.label_condition_type + )(self.cfg.label_condition) + + self.denoiser_model = step1x3d_geometry.find(self.cfg.denoiser_model_type)( + self.cfg.denoiser_model + ) + + self.noise_scheduler = step1x3d_geometry.find(self.cfg.noise_scheduler_type)( + **self.cfg.noise_scheduler + ) + + self.denoise_scheduler = step1x3d_geometry.find( + self.cfg.denoise_scheduler_type + )(**self.cfg.denoise_scheduler) + + def forward(self, batch: Dict[str, Any], skip_noise=False) -> Dict[str, Any]: + # 1. encode shape latents + if "sharp_surface" in batch.keys(): + sharp_surface = batch["sharp_surface"][ + ..., : 3 + self.cfg.shape_model.point_feats + ] + else: + sharp_surface = None + shape_embeds, kl_embed, _ = self.shape_model.encode( + batch["surface"][..., : 3 + self.cfg.shape_model.point_feats], + sample_posterior=True, + sharp_surface=sharp_surface, + ) + + latents = kl_embed * self.cfg.z_scale_factor + + # 2. gain visual condition + visual_cond_latents = None + if self.cfg.visual_condition_type is not None: + if "image" in batch and batch["image"].dim() == 5: + if self.training: + bs, n_images = batch["image"].shape[:2] + batch["image"] = batch["image"].view( + bs * n_images, *batch["image"].shape[-3:] + ) + else: + batch["image"] = batch["image"][:, 0, ...] + n_images = 1 + bs = batch["image"].shape[0] + visual_cond_latents = self.visual_condition(batch).to(latents) + latents = latents.unsqueeze(1).repeat(1, n_images, 1, 1) + latents = latents.view(bs * n_images, *latents.shape[-2:]) + else: + visual_cond_latents = self.visual_condition(batch).to(latents) + + ## 2.1 text condition if provided + caption_cond_latents = None + if self.cfg.caption_condition_type is not None: + assert "caption" in batch.keys(), "caption is required for caption encoder" + assert bs == len( + batch["caption"] + ), "Batch size must be the same as the caption length." + caption_cond_latents = ( + self.caption_condition(batch) + .repeat_interleave(n_images, dim=0) + .to(latents) + ) + + ## 2.2 label condition if provided + label_cond_latents = None + if self.cfg.label_condition_type is not None: + assert "label" in batch.keys(), "label is required for label encoder" + assert bs == len( + batch["label"] + ), "Batch size must be the same as the label length." + label_cond_latents = ( + self.label_condition(batch) + .repeat_interleave(n_images, dim=0) + .to(latents) + ) + + # 3. sample noise that we"ll add to the latents + noise = torch.randn_like(latents).to( + latents + ) # [batch_size, n_token, latent_dim] + bs = latents.shape[0] + + # 4. Sample a random timestep for each motion + timesteps = torch.randint( + 0, + self.cfg.noise_scheduler.num_train_timesteps, + (bs,), + device=latents.device, + ) + timesteps = timesteps.long() + + # 5. add noise + noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) + + # 6. diffusion model forward + output = self.denoiser_model( + noisy_z, + timesteps.long(), + visual_cond_latents, + caption_cond_latents, + label_cond_latents, + ).sample + + # 7. compute loss + if self.noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif self.noise_scheduler.config.prediction_type == "v_prediction": + target = self.noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f"Prediction Type: {self.noise_scheduler.prediction_type} not supported." + ) + if self.cfg.snr_gamma == 0: + if self.cfg.loss.loss_type == "l1": + loss = F.l1_loss(output, target, reduction="mean") + elif self.cfg.loss.loss_type in ["mse", "l2"]: + loss = F.mse_loss(output, target, reduction="mean") + else: + raise ValueError(f"Loss Type: {self.cfg.loss.loss_type} not supported.") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(self.noise_scheduler, timesteps) + mse_loss_weights = torch.stack( + [snr, self.cfg.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + if self.noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = mse_loss_weights / snr + elif self.noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = mse_loss_weights / (snr + 1) + + if self.cfg.loss.loss_type == "l1": + loss = F.l1_loss(output, target, reduction="none") + elif self.cfg.loss.loss_type in ["mse", "l2"]: + loss = F.mse_loss(output, target, reduction="none") + else: + raise ValueError(f"Loss Type: {self.cfg.loss.loss_type} not supported.") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + return { + "loss_diffusion": loss, + "latents": latents, + "x_t": noisy_z, + "noise": noise, + "noise_pred": output, + "timesteps": timesteps, + } + + def training_step(self, batch, batch_idx): + out = self(batch) + + loss = 0.0 + for name, value in out.items(): + if name.startswith("loss_"): + self.log(f"train/{name}", value) + loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) + + for name, value in self.cfg.loss.items(): + if name.startswith("lambda_"): + self.log(f"train_params/{name}", self.C(value)) + + return {"loss": loss} + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + if self.cfg.skip_validation: + return {} + self.eval() + + if get_rank() == 0: + sample_inputs = json.loads( + open(self.cfg.val_samples_json).read() + ) # condition + sample_inputs_ = copy.deepcopy(sample_inputs) + sample_outputs = self.sample(sample_inputs) # list + for i, latents in enumerate(sample_outputs["latents"]): + meshes = self.shape_model.extract_geometry( + latents, + bounds=self.cfg.bounds, + mc_level=self.cfg.mc_level, + octree_resolution=self.cfg.octree_resolution, + enable_pbar=False, + ) + + for j in range(len(meshes)): + name = "" + if "image" in sample_inputs_: + name += ( + sample_inputs_["image"][j] + .split("/")[-1] + .replace(".png", "") + ) + elif "mvimages" in sample_inputs_: + name += ( + sample_inputs_["mvimages"][j][0] + .split("/")[-2] + .replace(".png", "") + ) + + if "caption" in sample_inputs_: + name += "_" + sample_inputs_["caption"][j].replace(" ", "_") + + if "label" in sample_inputs_: + name += ( + "_" + + sample_inputs_["label"][j]["symmetry"] + + sample_inputs_["label"][j]["edge_type"] + ) + + if ( + meshes[j].verts is not None + and meshes[j].verts.shape[0] > 0 + and meshes[j].faces is not None + and meshes[j].faces.shape[0] > 0 + ): + self.save_mesh( + f"it{self.true_global_step}/{name}_{i}.obj", + meshes[j].verts, + meshes[j].faces, + ) + torch.cuda.empty_cache() + + out = self(batch) + if self.global_step == 0: + latents = self.shape_model.decode(out["latents"]) + meshes = self.shape_model.extract_geometry( + latents, + bounds=self.cfg.bounds, + mc_level=self.cfg.mc_level, + octree_resolution=self.cfg.octree_resolution, + enable_pbar=False, + ) + + for i, mesh in enumerate(meshes): + self.save_mesh( + f"it{self.true_global_step}/{batch['uid'][i]}.obj", + mesh.verts, + mesh.faces, + ) + + return {"val/loss": out["loss_diffusion"]} + + @torch.no_grad() + def sample( + self, + sample_inputs: Dict[str, Union[torch.FloatTensor, List[str]]], + sample_times: int = 1, + steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + eta: float = 0.0, + seed: Optional[int] = None, + **kwargs, + ): + + if steps is None: + steps = self.cfg.num_inference_steps + if guidance_scale is None: + guidance_scale = self.cfg.guidance_scale + do_classifier_free_guidance = guidance_scale != 1.0 + + # conditional encode + visal_cond = None + if "image" in sample_inputs: + sample_inputs["image"] = [ + Image.open(img) if type(img) == str else img + for img in sample_inputs["image"] + ] + sample_inputs["image"] = Step1X3DGeometryPipeline.preprocess_image( + sample_inputs["image"], **kwargs + ) + cond = self.visual_condition.encode_image(sample_inputs["image"]) + if do_classifier_free_guidance: + un_cond = self.visual_condition.empty_image_embeds.repeat( + len(sample_inputs["image"]), 1, 1 + ).to(cond) + visal_cond = torch.cat([un_cond, cond], dim=0) + caption_cond = None + if "caption" in sample_inputs: + cond = self.label_condition.encode_label(sample_inputs["caption"]) + if do_classifier_free_guidance: + un_cond = self.caption_condition.empty_caption_embeds.repeat( + len(sample_inputs["caption"]), 1, 1 + ).to(cond) + caption_cond = torch.cat([un_cond, cond], dim=0) + label_cond = None + if "label" in sample_inputs: + cond = self.label_condition.encode_label(sample_inputs["label"]) + if do_classifier_free_guidance: + un_cond = self.label_condition.empty_label_embeds.repeat( + len(sample_inputs["label"]), 1 + ).to(cond) + label_cond = torch.cat([un_cond, cond], dim=0) + + latents_list = [] + if seed != None: + generator = torch.Generator(device="cuda").manual_seed(seed) + else: + generator = None + + for _ in range(sample_times): + sample_loop = ddim_sample( + self.denoise_scheduler, + self.denoiser_model.eval(), + shape=self.shape_model.latent_shape, + visual_cond=visal_cond, + caption_cond=caption_cond, + label_cond=label_cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=False, + generator=generator, + ) + for sample, t in sample_loop: + latents = sample + latents_list.append(self.shape_model.decode(latents)) + + return {"latents": latents_list, "inputs": sample_inputs} + + def on_validation_epoch_end(self): + pass + + def test_step(self, batch, batch_idx): + return diff --git a/step1x3d_geometry/systems/shape_rectified_flow.py b/step1x3d_geometry/systems/shape_rectified_flow.py new file mode 100755 index 0000000000000000000000000000000000000000..3121af38f9c8d56c5872a438ee900dbf3f32f9c4 --- /dev/null +++ b/step1x3d_geometry/systems/shape_rectified_flow.py @@ -0,0 +1,474 @@ +from dataclasses import dataclass, field + +import numpy as np +import json +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F +from skimage import measure +from einops import repeat +from tqdm import tqdm +from PIL import Image + +from diffusers import ( + DDPMScheduler, + DDIMScheduler, + UniPCMultistepScheduler, + KarrasVeScheduler, + DPMSolverMultistepScheduler, +) +from diffusers.training_utils import ( + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, +) +import step1x3d_geometry +from step1x3d_geometry.systems.base import BaseSystem +from step1x3d_geometry.utils.misc import get_rank +from step1x3d_geometry.utils.typing import * +from step1x3d_geometry.systems.utils import read_image, preprocess_image, flow_sample + + +def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=timesteps.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(timesteps.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +@step1x3d_geometry.register("rectified-flow-system") +class RectifiedFlowSystem(BaseSystem): + @dataclass + class Config(BaseSystem.Config): + skip_validation: bool = True + val_samples_json: str = "" + bounds: float = 1.05 + mc_level: float = 0.0 + octree_resolution: int = 256 + + # diffusion config + guidance_scale: float = 7.5 + num_inference_steps: int = 30 + eta: float = 0.0 + snr_gamma: float = 5.0 + + # flow + weighting_scheme: str = "logit_normal" + logit_mean: float = 0 + logit_std: float = 1.0 + mode_scale: float = 1.29 + precondition_outputs: bool = True + precondition_t: int = 1000 + + # shape vae model + shape_model_type: str = None + shape_model: dict = field(default_factory=dict) + + # condition model + visual_condition_type: Optional[str] = None + visual_condition: dict = field(default_factory=dict) + caption_condition_type: Optional[str] = None + caption_condition: dict = field(default_factory=dict) + label_condition_type: Optional[str] = None + label_condition: dict = field(default_factory=dict) + + # diffusion model + denoiser_model_type: str = None + denoiser_model: dict = field(default_factory=dict) + + # noise scheduler + noise_scheduler_type: str = None + noise_scheduler: dict = field(default_factory=dict) + + # denoise scheduler + denoise_scheduler_type: str = None + denoise_scheduler: dict = field(default_factory=dict) + + # lora + use_lora: bool = False + lora_layers: Optional[str] = None + rank: int = 128 # The dimension of the LoRA update matrices. + alpha: int = 128 + + cfg: Config + + def configure(self): + super().configure() + + self.shape_model = step1x3d_geometry.find(self.cfg.shape_model_type)( + self.cfg.shape_model + ) + self.shape_model.eval() + self.shape_model.requires_grad_(False) + + if self.cfg.visual_condition_type is not None: + self.visual_condition = step1x3d_geometry.find( + self.cfg.visual_condition_type + )(self.cfg.visual_condition) + self.visual_condition.requires_grad_(False) + + if self.cfg.caption_condition_type is not None: + self.caption_condition = step1x3d_geometry.find( + self.cfg.caption_condition_type + )(self.cfg.caption_condition) + self.caption_condition.requires_grad_(False) + + if self.cfg.label_condition_type is not None: + self.label_condition = step1x3d_geometry.find( + self.cfg.label_condition_type + )(self.cfg.label_condition) + + self.denoiser_model = step1x3d_geometry.find(self.cfg.denoiser_model_type)( + self.cfg.denoiser_model + ) + if self.cfg.use_lora: # We only train the additional adapter LoRA layers + self.denoiser_model.requires_grad_(False) + + self.noise_scheduler = step1x3d_geometry.find(self.cfg.noise_scheduler_type)( + **self.cfg.noise_scheduler + ) + self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler) + + self.denoise_scheduler = step1x3d_geometry.find( + self.cfg.denoise_scheduler_type + )(**self.cfg.denoise_scheduler) + + if self.cfg.use_lora: + from peft import LoraConfig, set_peft_model_state_dict + + if self.cfg.lora_layers is not None: + self.target_modules = [ + layer.strip() for layer in self.cfg.lora_layers.split(",") + ] + else: + self.target_modules = [ + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_out.0", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "ff.net.0.proj", + "ff.net.2", + "ff_context.net.0.proj", + "ff_context.net.2", + ] + self.transformer_lora_config = LoraConfig( + r=self.cfg.rank, + lora_alpha=self.cfg.alpha, + init_lora_weights="gaussian", + target_modules=self.target_modules, + ) + self.denoiser_model.dit_model.add_adapter(self.transformer_lora_config) + + def forward(self, batch: Dict[str, Any], skip_noise=False) -> Dict[str, Any]: + # 1. encode shape latents + if "sharp_surface" in batch.keys(): + sharp_surface = batch["sharp_surface"][ + ..., : 3 + self.cfg.shape_model.point_feats + ] + else: + sharp_surface = None + shape_embeds, latents, _ = self.shape_model.encode( + batch["surface"][..., : 3 + self.cfg.shape_model.point_feats], + sample_posterior=True, + sharp_surface=sharp_surface, + ) + + # 2. gain visual condition + visual_cond = None + if self.cfg.visual_condition_type is not None: + assert "image" in batch.keys(), "image is required for label encoder" + if "image" in batch and batch["image"].dim() == 5: + if self.training: + bs, n_images = batch["image"].shape[:2] + batch["image"] = batch["image"].view( + bs * n_images, *batch["image"].shape[-3:] + ) + else: + batch["image"] = batch["image"][:, 0, ...] + n_images = 1 + bs = batch["image"].shape[0] + visual_cond = self.visual_condition(batch).to(latents) + latents = latents.unsqueeze(1).repeat(1, n_images, 1, 1) + latents = latents.view(bs * n_images, *latents.shape[-2:]) + else: + visual_cond = self.visual_condition(batch).to(latents) + bs = visual_cond.shape[0] + n_images = 1 + + ## 2.1 text condition if provided + caption_cond = None + if self.cfg.caption_condition_type is not None: + assert "caption" in batch.keys(), "caption is required for caption encoder" + assert bs == len( + batch["caption"] + ), "Batch size must be the same as the caption length." + caption_cond = ( + self.caption_condition(batch) + .repeat_interleave(n_images, dim=0) + .to(latents) + ) + + ## 2.2 label condition if provided + label_cond = None + if self.cfg.label_condition_type is not None: + assert "label" in batch.keys(), "label is required for label encoder" + assert bs == len( + batch["label"] + ), "Batch size must be the same as the label length." + label_cond = ( + self.label_condition(batch) + .repeat_interleave(n_images, dim=0) + .to(latents) + ) + + # 3. sample noise that we"ll add to the latents + noise = torch.randn_like(latents).to( + latents + ) # [batch_size, n_token, latent_dim] + + # 4. Sample a random timestep + u = compute_density_for_timestep_sampling( + weighting_scheme=self.cfg.weighting_scheme, + batch_size=bs * n_images, + logit_mean=self.cfg.logit_mean, + logit_std=self.cfg.logit_std, + mode_scale=self.cfg.mode_scale, + ) + indices = (u * self.cfg.noise_scheduler.num_train_timesteps).long() + timesteps = self.noise_scheduler_copy.timesteps[indices].to( + device=latents.device + ) + + # 5. add noise + sigmas = get_sigmas( + self.noise_scheduler_copy, timesteps, n_dim=3, dtype=latents.dtype + ) + noisy_z = (1.0 - sigmas) * latents + sigmas * noise + + # 6. diffusion model forward + output = self.denoiser_model( + noisy_z, timesteps.long(), visual_cond, caption_cond, label_cond + ).sample + + # 7. compute loss + if self.cfg.precondition_outputs: + output = output * (-sigmas) + noisy_z + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3( + weighting_scheme=self.cfg.weighting_scheme, sigmas=sigmas + ) + # flow matching loss + if self.cfg.precondition_outputs: + target = latents + else: + target = noise - latents + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (output.float() - target.float()) ** 2).reshape( + target.shape[0], -1 + ), + 1, + ) + loss = loss.mean() + + return { + "loss_diffusion": loss, + "latents": latents, + "x_t": noisy_z, + "noise": noise, + "noise_pred": output, + "timesteps": timesteps, + } + + def training_step(self, batch, batch_idx): + out = self(batch) + + loss = 0.0 + for name, value in out.items(): + if name.startswith("loss_"): + self.log(f"train/{name}", value) + loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) + if name.startswith("log_"): + self.log(f"log/{name.replace('log_', '')}", value.mean()) + + for name, value in self.cfg.loss.items(): + if name.startswith("lambda_"): + self.log(f"train_params/{name}", self.C(value)) + + return {"loss": loss} + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + if self.cfg.skip_validation: + return {} + self.eval() + + if get_rank() == 0: + sample_inputs = json.loads( + open(self.cfg.val_samples_json).read() + ) # condition + sample_inputs_ = copy.deepcopy(sample_inputs) + sample_outputs = self.sample(sample_inputs) # list + for i, latents in enumerate(sample_outputs["latents"]): + meshes = self.shape_model.extract_geometry( + latents, + bounds=self.cfg.bounds, + mc_level=self.cfg.mc_level, + octree_resolution=self.cfg.octree_resolution, + enable_pbar=False, + ) + + for j in range(len(meshes)): + name = "" + if "image" in sample_inputs_: + name += ( + sample_inputs_["image"][j] + .split("/")[-1] + .replace(".png", "") + ) + + elif "mvimages" in sample_inputs_: + name += ( + sample_inputs_["mvimages"][j][0] + .split("/")[-2] + .replace(".png", "") + ) + + if "caption" in sample_inputs_: + name += "_" + sample_inputs_["caption"][j].replace( + " ", "_" + ).replace(".", "") + + if "label" in sample_inputs_: + name += ( + "_" + + sample_inputs_["label"][j]["symmetry"] + + sample_inputs_["label"][j]["edge_type"] + ) + + if ( + meshes[j].verts is not None + and meshes[j].verts.shape[0] > 0 + and meshes[j].faces is not None + and meshes[j].faces.shape[0] > 0 + ): + self.save_mesh( + f"it{self.true_global_step}/{name}_{i}.obj", + meshes[j].verts, + meshes[j].faces, + ) + torch.cuda.empty_cache() + + out = self(batch) + if self.global_step == 0: + latents = self.shape_model.decode(out["latents"]) + meshes = self.shape_model.extract_geometry( + latents, + bounds=self.cfg.bounds, + mc_level=self.cfg.mc_level, + octree_resolution=self.cfg.octree_resolution, + enable_pbar=False, + ) + + for i, mesh in enumerate(meshes): + self.save_mesh( + f"it{self.true_global_step}/{batch['uid'][i]}.obj", + mesh.verts, + mesh.faces, + ) + + return {"val/loss": out["loss_diffusion"]} + + @torch.no_grad() + def sample( + self, + sample_inputs: Dict[str, Union[torch.FloatTensor, List[str]]], + sample_times: int = 1, + steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + eta: float = 0.0, + seed: Optional[int] = None, + **kwargs, + ): + + if steps is None: + steps = self.cfg.num_inference_steps + if guidance_scale is None: + guidance_scale = self.cfg.guidance_scale + do_classifier_free_guidance = guidance_scale != 1.0 + + # conditional encode + visal_cond = None + if "image" in sample_inputs: + sample_inputs["image"] = [ + Image.open(img) if type(img) == str else img + for img in sample_inputs["image"] + ] + sample_inputs["image"] = preprocess_image(sample_inputs["image"], **kwargs) + cond = self.visual_condition.encode_image(sample_inputs["image"]) + if do_classifier_free_guidance: + un_cond = self.visual_condition.empty_image_embeds.repeat( + len(sample_inputs["image"]), 1, 1 + ).to(cond) + visal_cond = torch.cat([un_cond, cond], dim=0) + caption_cond = None + if "caption" in sample_inputs: + cond = self.label_condition.encode_label(sample_inputs["caption"]) + if do_classifier_free_guidance: + un_cond = self.caption_condition.empty_caption_embeds.repeat( + len(sample_inputs["caption"]), 1, 1 + ).to(cond) + caption_cond = torch.cat([un_cond, cond], dim=0) + label_cond = None + if "label" in sample_inputs: + cond = self.label_condition.encode_label(sample_inputs["label"]) + if do_classifier_free_guidance: + un_cond = self.label_condition.empty_label_embeds.repeat( + len(sample_inputs["label"]), 1, 1 + ).to(cond) + label_cond = torch.cat([un_cond, cond], dim=0) + + latents_list = [] + if seed != None: + generator = torch.Generator(device="cuda").manual_seed(seed) + else: + generator = None + + for _ in range(sample_times): + sample_loop = flow_sample( + self.denoise_scheduler, + self.denoiser_model.eval(), + shape=self.shape_model.latent_shape, + visual_cond=visal_cond, + caption_cond=caption_cond, + label_cond=label_cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=False, + generator=generator, + ) + for sample, t in sample_loop: + latents = sample + latents_list.append(self.shape_model.decode(latents)) + + return {"latents": latents_list, "inputs": sample_inputs} + + def on_validation_epoch_end(self): + pass + + def test_step(self, batch, batch_idx): + return diff --git a/step1x3d_geometry/systems/utils.py b/step1x3d_geometry/systems/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..71d5a9862a52dc8a6d6998be7615520fd779f8e7 --- /dev/null +++ b/step1x3d_geometry/systems/utils.py @@ -0,0 +1,391 @@ +import torch +import numpy as np + +import rembg +from PIL import Image +from tqdm import tqdm +from diffusers import DDIMScheduler +from torchvision import transforms + +from step1x3d_geometry.utils.typing import * +from step1x3d_geometry.utils.misc import get_device + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@torch.no_grad() +def ddim_sample( + ddim_scheduler: DDIMScheduler, + diffusion_model: torch.nn.Module, + shape: Union[List[int], Tuple[int]], + visual_cond: torch.FloatTensor, + caption_cond: torch.FloatTensor, + label_cond: torch.FloatTensor, + steps: int, + eta: float = 0.0, + guidance_scale: float = 3.0, + do_classifier_free_guidance: bool = True, + generator: Optional[torch.Generator] = None, + device: torch.device = "cuda:0", + disable_prog: bool = True, +): + + assert steps > 0, f"{steps} must > 0." + + # init latents + if visual_cond is not None: + bsz = visual_cond.shape[0] + device = visual_cond.device + dtype = visual_cond.dtype + if caption_cond is not None: + bsz = caption_cond.shape[0] + device = caption_cond.device + dtype = caption_cond.dtype + if label_cond is not None: + bsz = label_cond.shape[0] + device = label_cond.device + dtype = label_cond.dtype + + if do_classifier_free_guidance: + bsz = bsz // 2 + latents = torch.randn( + (bsz, *shape), + generator=generator, + device=device, + dtype=dtype, + ) + try: + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * scheduler.init_noise_sigma + except AttributeError: + pass + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + extra_step_kwargs = {"generator": generator} + + # set timesteps + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + steps, + device, + ) + if eta > 0: + assert 0 <= eta <= 1, f"eta must be between [0, 1]. Got {eta}." + assert ( + scheduler.__class__.__name__ == "DDIMScheduler" + ), f"eta is only used with the DDIMScheduler." + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, and between [0, 1] + extra_step_kwargs["eta"] = eta + + # reverse + for i, t in enumerate( + tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False) + ): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + + # predict the noise residual + timestep_tensor = torch.tensor([t], dtype=torch.long, device=device) + timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) + noise_pred = diffusion_model.forward( + latent_model_input, timestep_tensor, visual_cond, caption_cond, label_cond + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = ddim_scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + yield latents, t + + +@torch.no_grad() +def flow_sample( + scheduler: DDIMScheduler, + diffusion_model: torch.nn.Module, + shape: Union[List[int], Tuple[int]], + visual_cond: torch.FloatTensor, + caption_cond: torch.FloatTensor, + label_cond: torch.FloatTensor, + steps: int, + eta: float = 0.0, + guidance_scale: float = 3.0, + do_classifier_free_guidance: bool = True, + generator: Optional[torch.Generator] = None, + device: torch.device = "cuda:0", + disable_prog: bool = True, +): + + assert steps > 0, f"{steps} must > 0." + + # init latents + if visual_cond is not None: + bsz = visual_cond.shape[0] + device = visual_cond.device + dtype = visual_cond.dtype + if caption_cond is not None: + bsz = caption_cond.shape[0] + device = caption_cond.device + dtype = caption_cond.dtype + if label_cond is not None: + bsz = label_cond.shape[0] + device = label_cond.device + dtype = label_cond.dtype + + if do_classifier_free_guidance: + bsz = bsz // 2 + latents = torch.randn( + (bsz, *shape), + generator=generator, + device=device, + dtype=dtype, + ) + try: + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * scheduler.init_noise_sigma + except AttributeError: + pass + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + extra_step_kwargs = {"generator": generator} + + # set timesteps + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + steps + 1, + device, + ) + if eta > 0: + assert 0 <= eta <= 1, f"eta must be between [0, 1]. Got {eta}." + assert ( + scheduler.__class__.__name__ == "DDIMScheduler" + ), f"eta is only used with the DDIMScheduler." + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, and between [0, 1] + extra_step_kwargs["eta"] = eta + + # reverse + distance = (timesteps[:-1] - timesteps[1:]) / scheduler.config.num_train_timesteps + for i, t in enumerate( + tqdm(timesteps[:-1], disable=disable_prog, desc="Flow Sampling:", leave=False) + ): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + # predict the noise residual + timestep_tensor = torch.tensor([t], dtype=latents.dtype, device=device) + timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) + noise_pred = diffusion_model.forward( + latent_model_input, timestep_tensor, visual_cond, caption_cond, label_cond + ).sample + if isinstance(noise_pred, tuple): + noise_pred, layer_idx_list, ones_list, pred_c_list = noise_pred + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = latents - distance[i] * noise_pred + + yield latents, t + + +def compute_snr(noise_scheduler, timesteps): + """ + Computes SNR as per + https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + +def read_image(img, img_size=224): + transform = transforms.Compose( + [ + transforms.Resize( + img_size, transforms.InterpolationMode.BICUBIC, antialias=True + ), + transforms.CenterCrop(img_size), # crop a (224, 224) square + transforms.ToTensor(), + ] + ) + rgb = Image.open(img) + rgb = transform(rgb)[:3, ...].permute(1, 2, 0) + return rgb + + +def preprocess_image( + images_pil: List[Image.Image], + force: bool = False, + background_color: List[int] = [255, 255, 255], + foreground_ratio: float = 0.95, +): + r""" + Crop and remote the background of the input image + Args: + image_pil (`List[PIL.Image.Image]`): + List of `PIL.Image.Image` objects representing the input image. + force (`bool`, *optional*, defaults to `False`): + Whether to force remove the background even if the image has an alpha channel. + Returns: + `List[PIL.Image.Image]`: List of `PIL.Image.Image` objects representing the preprocessed image. + """ + preprocessed_images = [] + for i in range(len(images_pil)): + image = images_pil[i] + width, height, size = image.width, image.height, image.size + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + # explain why current do not rm bg + print( + "alhpa channl not empty, skip remove background, using alpha channel as mask" + ) + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image) + + # calculate the min bbox of the image + alpha = image.split()[-1] + bboxs = alpha.getbbox() + x1, y1, x2, y2 = bboxs + dy, dx = y2 - y1, x2 - x1 + s = min(height * foreground_ratio / dy, width * foreground_ratio / dx) + Ht, Wt = int(dy * s), int(dx * s) + + background = Image.new("RGBA", image.size, (*background_color, 255)) + image = Image.alpha_composite(background, image) + image = image.crop(alpha.getbbox()) + alpha = alpha.crop(alpha.getbbox()) + + # Calculate the new size after rescaling + new_size = tuple(int(dim * foreground_ratio) for dim in size) + # Resize the image while maintaining the aspect ratio + resized_image = image.resize((Wt, Ht)) + resized_alpha = alpha.resize((Wt, Ht)) + # Create a new image with the original size and white background + padded_image = Image.new("RGB", size, tuple(background_color)) + padded_alpha = Image.new("L", size, (0)) + paste_position = ( + (width - resized_image.width) // 2, + (height - resized_image.height) // 2, + ) + padded_image.paste(resized_image, paste_position) + padded_alpha.paste(resized_alpha, paste_position) + + # expand image to 1:1 + width, height = padded_image.size + if width == height: + padded_image.putalpha(padded_alpha) + preprocessed_images.append(padded_image) + continue + new_size = (max(width, height), max(width, height)) + new_image = Image.new("RGB", new_size, tuple(background_color)) + new_alpha = Image.new("L", new_size, (0)) + paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) + new_image.paste(padded_image, paste_position) + new_alpha.paste(padded_alpha, paste_position) + new_image.putalpha(new_alpha) + preprocessed_images.append(new_image) + + return preprocessed_images diff --git a/step1x3d_geometry/utils/__init__.py b/step1x3d_geometry/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..0e44449338cf3ff3bb3d124ef22c8f7bbca760a5 --- /dev/null +++ b/step1x3d_geometry/utils/__init__.py @@ -0,0 +1 @@ +from . import base diff --git a/step1x3d_geometry/utils/base.py b/step1x3d_geometry/utils/base.py new file mode 100755 index 0000000000000000000000000000000000000000..bd022bd9da644427cddda98a5863dbf449ce9fd7 --- /dev/null +++ b/step1x3d_geometry/utils/base.py @@ -0,0 +1,215 @@ +from dataclasses import dataclass + +import os +import copy +import json +from omegaconf import OmegaConf +import torch +import torch.nn as nn + +from diffusers.models.modeling_utils import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import ( + extract_commit_hash, +) + +from step1x3d_geometry.utils.config import parse_structured +from step1x3d_geometry.utils.misc import get_device, load_module_weights +from step1x3d_geometry.utils.typing import * + + +class Configurable: + @dataclass + class Config: + pass + + def __init__(self, cfg: Optional[dict] = None) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + + +class Updateable: + def do_update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step( + epoch, global_step, on_load_weights=on_load_weights + ) + self.update_step(epoch, global_step, on_load_weights=on_load_weights) + + def do_update_step_end(self, epoch: int, global_step: int): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + self.update_step_end(epoch, global_step) + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # override this method to implement custom update logic + # if on_load_weights is True, you should be careful doing things related to model evaluations, + # as the models and tensors are not guarenteed to be on the same device + pass + + def update_step_end(self, epoch: int, global_step: int): + pass + + +def update_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step(epoch, global_step) + + +def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + + +class BaseObject(Updateable): + @dataclass + class Config: + pass + + cfg: Config # add this to every subclass of BaseObject to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.device = get_device() + self.configure(*args, **kwargs) + + def configure(self, *args, **kwargs) -> None: + pass + + +class BaseModule(ModelMixin, Updateable, nn.Module): + @dataclass + class Config: + weights: Optional[str] = None + + cfg: Config # add this to every subclass of BaseModule to enable static type checking + config_name = "config.json" + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + # self.device = get_device() + self.configure(*args, **kwargs) + if self.cfg.weights is not None: + # format: path/to/weights:module_name + weights_path, module_name = self.cfg.weights.split(":") + state_dict, epoch, global_step = load_module_weights( + weights_path, module_name=module_name, map_location="cpu" + ) + self.load_state_dict(state_dict) + self.do_update_step( + epoch, global_step, on_load_weights=True + ) # restore states + # dummy tensor to indicate model state + self._dummy: Float[Tensor, "..."] + self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False) + + def configure(self, *args, **kwargs) -> None: + pass + + @classmethod + def load_config( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + return_unused_kwargs=False, + return_commit_hash=False, + **kwargs, + ): + subfolder = kwargs.pop("subfolder", None) + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join( + pretrained_model_name_or_path, subfolder, cls.config_name + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, cls.config_name) + ): + # Load from a PyTorch checkpoint + config_file = os.path.join( + pretrained_model_name_or_path, cls.config_name + ) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + raise ValueError + + config_dict = json.load(open(config_file, "r")) + commit_hash = extract_commit_hash(config_file) + + outputs = (config_dict,) + + if return_unused_kwargs: + outputs += (kwargs,) + + if return_commit_hash: + outputs += (commit_hash,) + + return outputs + + @classmethod + def from_config(cls, config: Dict[str, Any] = None, **kwargs): + model = cls(config) + return model + + def register_to_config(self, **kwargs): + pass + + def save_config(self, save_directory: Union[str, os.PathLike], **kwargs): + """ + Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the + [`~ConfigMixin.from_config`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file is saved (will be created if it does not exist). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + raise AssertionError( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + config_dict = OmegaConf.to_container(self.cfg, resolve=True) + for k in copy.deepcopy(config_dict).keys(): + if k.startswith("pretrained"): + config_dict.pop(k) + config_dict.pop("weights") + with open(output_config_file, "w", encoding="utf-8") as f: + json.dump(config_dict, f, ensure_ascii=False, indent=4) + + print(f"Configuration saved in {output_config_file}") diff --git a/step1x3d_geometry/utils/callbacks.py b/step1x3d_geometry/utils/callbacks.py new file mode 100755 index 0000000000000000000000000000000000000000..71747a44ca32d88ac71a11320d7c3f1a96bb4cce --- /dev/null +++ b/step1x3d_geometry/utils/callbacks.py @@ -0,0 +1,176 @@ +import os +import shutil +import subprocess + +import pytorch_lightning + +from step1x3d_geometry.utils.config import dump_config +from step1x3d_geometry.utils.misc import parse_version + +if parse_version(pytorch_lightning.__version__) > parse_version("1.8"): + from pytorch_lightning.callbacks import Callback +else: + from pytorch_lightning.callbacks.base import Callback + +from pytorch_lightning.callbacks.progress import TQDMProgressBar +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn + + +class EarlyEnvironmentSetter(Callback): + def __init__(self): + super().__init__() + self.rank_set = False + + def setup(self, trainer, pl_module, stage): + if not self.rank_set: + world_size = trainer.num_devices + local_rank = trainer.strategy.local_rank + + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["RANK"] = str(local_rank) + + self.rank_set = True + + +class VersionedCallback(Callback): + def __init__(self, save_root, version=None, use_version=True): + self.save_root = save_root + self._version = version + self.use_version = use_version + + @property + def version(self) -> int: + """Get the experiment version. + + Returns: + The experiment version if specified else the next version. + """ + if self._version is None: + self._version = self._get_next_version() + return self._version + + def _get_next_version(self): + existing_versions = [] + if os.path.isdir(self.save_root): + for f in os.listdir(self.save_root): + bn = os.path.basename(f) + if bn.startswith("version_"): + dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") + existing_versions.append(int(dir_ver)) + if len(existing_versions) == 0: + return 0 + return max(existing_versions) + 1 + + @property + def savedir(self): + if not self.use_version: + return self.save_root + return os.path.join( + self.save_root, + ( + self.version + if isinstance(self.version, str) + else f"version_{self.version}" + ), + ) + + +class CodeSnapshotCallback(VersionedCallback): + def __init__(self, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + + def get_file_list(self): + return [ + b.decode() + for b in set( + subprocess.check_output( + 'git ls-files -- ":!:load/*"', shell=True + ).splitlines() + ) + | set( # hard code, TODO: use config to exclude folders or files + subprocess.check_output( + "git ls-files --others --exclude-standard", shell=True + ).splitlines() + ) + ] + + @rank_zero_only + def save_code_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + for f in self.get_file_list(): + if not os.path.exists(f) or os.path.isdir(f): + continue + os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) + shutil.copyfile(f, os.path.join(self.savedir, f)) + + def on_fit_start(self, trainer, pl_module): + try: + self.save_code_snapshot() + except: + rank_zero_warn( + "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." + ) + + +class ConfigSnapshotCallback(VersionedCallback): + def __init__(self, config_path, config, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + self.config_path = config_path + self.config = config + + @rank_zero_only + def save_config_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config) + shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml")) + + def on_fit_start(self, trainer, pl_module): + self.save_config_snapshot() + + +class CustomProgressBar(TQDMProgressBar): + def get_metrics(self, *args, **kwargs): + # don't show the version number + items = super().get_metrics(*args, **kwargs) + items.pop("v_num", None) + return items + + +class ProgressCallback(Callback): + def __init__(self, save_path): + super().__init__() + self.save_path = save_path + self._file_handle = None + + @property + def file_handle(self): + if self._file_handle is None: + self._file_handle = open(self.save_path, "w") + return self._file_handle + + @rank_zero_only + def write(self, msg: str) -> None: + self.file_handle.seek(0) + self.file_handle.truncate() + self.file_handle.write(msg) + self.file_handle.flush() + + @rank_zero_only + def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): + self.write( + f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%" + ) + + @rank_zero_only + def on_validation_start(self, trainer, pl_module): + self.write(f"Rendering validation image ...") + + @rank_zero_only + def on_test_start(self, trainer, pl_module): + self.write(f"Rendering video ...") + + @rank_zero_only + def on_predict_start(self, trainer, pl_module): + self.write(f"Exporting mesh assets ...") diff --git a/step1x3d_geometry/utils/checkpoint.py b/step1x3d_geometry/utils/checkpoint.py new file mode 100755 index 0000000000000000000000000000000000000000..7a1a8f5f49aa61ccbf5e82410bf1fc0c72b61f67 --- /dev/null +++ b/step1x3d_geometry/utils/checkpoint.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +""" +Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 +""" + +import torch +from step1x3d_geometry.utils.typing import * + + +def checkpoint( + func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], + inputs: Sequence[torch.Tensor], + params: Iterable[torch.Tensor], + flag: bool, + use_deepspeed: bool = False, +): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + :param use_deepspeed: if True, use deepspeed + """ + if flag: + if use_deepspeed: + import deepspeed + + return deepspeed.checkpointing.checkpoint(func, *inputs) + + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/step1x3d_geometry/utils/config.py b/step1x3d_geometry/utils/config.py new file mode 100755 index 0000000000000000000000000000000000000000..40d0efc4239ae796ad0721f96878894cb0879d48 --- /dev/null +++ b/step1x3d_geometry/utils/config.py @@ -0,0 +1,128 @@ +import os +from dataclasses import dataclass, field +from datetime import datetime + +from omegaconf import OmegaConf + +import step1x3d_geometry +from step1x3d_geometry.utils.typing import * + +# ============ Register OmegaConf Recolvers ============= # +OmegaConf.register_new_resolver( + "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) +) +OmegaConf.register_new_resolver("add", lambda a, b: a + b) +OmegaConf.register_new_resolver("sub", lambda a, b: a - b) +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) +OmegaConf.register_new_resolver("div", lambda a, b: a / b) +OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) +OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) +OmegaConf.register_new_resolver("rmspace", lambda s, sub: str(s).replace(" ", sub)) +OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) +OmegaConf.register_new_resolver("gt0", lambda s: s > 0) +OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0) +OmegaConf.register_new_resolver("not", lambda s: not s) +OmegaConf.register_new_resolver( + "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0 +) +# ======================================================= # + + +def C_max(value: Any) -> float: + if isinstance(value, int) or isinstance(value, float): + pass + else: + value = config_to_primitive(value) + if not isinstance(value, list): + raise TypeError("Scalar specification only supports list, got", type(value)) + if len(value) >= 6: + max_value = value[2] + for i in range(4, len(value), 2): + max_value = max(max_value, value[i]) + value = [value[0], value[1], max_value, value[3]] + if len(value) == 3: + value = [0] + value + assert len(value) == 4 + start_step, start_value, end_value, end_step = value + value = max(start_value, end_value) + return value + + +@dataclass +class ExperimentConfig: + name: str = "default" + description: str = "" + tag: str = "" + seed: int = 0 + use_timestamp: bool = True + timestamp: Optional[str] = None + exp_root_dir: str = "outputs" + + ### these shouldn't be set manually + exp_dir: str = "outputs/default" + trial_name: str = "exp" + trial_dir: str = "outputs/default/exp" + n_gpus: int = 1 + ### + + resume: Optional[str] = None + + data_type: str = "" + data: dict = field(default_factory=dict) + + system_type: str = "" + system: dict = field(default_factory=dict) + + # accept pytorch-lightning trainer parameters + # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api + trainer: dict = field(default_factory=dict) + + # accept pytorch-lightning checkpoint callback parameters + # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint + checkpoint: dict = field(default_factory=dict) + + def __post_init__(self): + if not self.tag and not self.use_timestamp: + raise ValueError("Either tag is specified or use_timestamp is True.") + self.trial_name = self.tag + # if resume from an existing config, self.timestamp should not be None + if self.timestamp is None: + self.timestamp = "" + if self.use_timestamp: + if self.n_gpus > 1: + step1x3d_geometry.warn( + "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." + ) + else: + self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") + self.trial_name += self.timestamp + self.exp_dir = os.path.join(self.exp_root_dir, self.name) + self.trial_dir = os.path.join(self.exp_dir, self.trial_name) + # os.makedirs(self.trial_dir, exist_ok=True) + + +def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: + if from_string: + yaml_confs = [OmegaConf.create(s) for s in yamls] + else: + yaml_confs = [OmegaConf.load(f) for f in yamls] + cli_conf = OmegaConf.from_cli(cli_args) + cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) + OmegaConf.resolve(cfg) + assert isinstance(cfg, DictConfig) + scfg = parse_structured(ExperimentConfig, cfg) + return scfg + + +def config_to_primitive(config, resolve: bool = True) -> Any: + return OmegaConf.to_container(config, resolve=resolve) + + +def dump_config(path: str, config) -> None: + with open(path, "w") as fp: + OmegaConf.save(config=config, f=fp) + + +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: + scfg = OmegaConf.structured(fields(**cfg)) + return scfg diff --git a/step1x3d_geometry/utils/ema.py b/step1x3d_geometry/utils/ema.py new file mode 100755 index 0000000000000000000000000000000000000000..c29320b5928d79e16272c8b78fee028b54285cf4 --- /dev/null +++ b/step1x3d_geometry/utils/ema.py @@ -0,0 +1,305 @@ +# ------------------------------------------------------------------------------------------------------------------------------------- +# Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): +# ------------------------------------------------------------------------------------------------------------------------------------- + +import torch +import os +import os.path +import warnings + +import pytorch_lightning as pl +from torch import Tensor +from pytorch_lightning import Callback +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import STEP_OUTPUT + +from typing import Any, Dict, List, Optional + +try: + import amp_C + + apex_available = True +except Exception: + apex_available = False + + +class EMA(Callback): + """ + Implements Exponential Moving Averaging (EMA). + When training a model, this callback will maintain moving averages of the trained parameters. + When evaluating, we use the moving averages copy of the trained parameters. + When saving, we save an additional set of parameters with the prefix `ema`. + Args: + decay: The exponential decay used when calculating the moving average. Has to be between 0-1. + apply_ema_every_n_steps: Apply EMA every n global steps. + start_step: Start applying EMA from ``start_step`` global step onwards. + save_ema_weights_in_callback_state: Enable saving EMA weights in callback state. + evaluate_ema_weights_instead: Validate the EMA weights instead of the original weights. + Note this means that when saving the model, the validation metrics are calculated with the EMA weights. + + Adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py + """ + + def __init__( + self, + decay: float = 0.999, + apply_ema_every_n_steps: int = 1, + start_step: int = 0, + # else .ckpt will save a model weights copy in key 'callback' + save_ema_weights_in_callback_state: bool = False, + evaluate_ema_weights_instead: bool = True, + ): + if not apex_available: + rank_zero_warn( + "EMA has better performance when Apex is installed: https://github.com/NVIDIA/apex#installation." + ) + if not (0 <= decay <= 1): + raise MisconfigurationException("EMA decay value must be between 0 and 1") + self._ema_model_weights: Optional[List[torch.Tensor]] = None + self._overflow_buf: Optional[torch.Tensor] = None + self._cur_step: Optional[int] = None + self._weights_buffer: Optional[List[torch.Tensor]] = None + self.apply_ema_every_n_steps = apply_ema_every_n_steps + self.start_step = start_step + self.save_ema_weights_in_callback_state = save_ema_weights_in_callback_state + self.evaluate_ema_weights_instead = evaluate_ema_weights_instead + self.decay = decay + + def on_train_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + rank_zero_info("Creating EMA weights copy.") + if self._ema_model_weights is None: + self._ema_model_weights = [ + p.detach().clone() for p in pl_module.state_dict().values() + ] + # ensure that all the weights are on the correct device + self._ema_model_weights = [ + p.to(pl_module.device) for p in self._ema_model_weights + ] + self._overflow_buf = torch.IntTensor([0]).to(pl_module.device) + + def ema(self, pl_module: "pl.LightningModule") -> None: + if apex_available and pl_module.device.type == "cuda": + return self.apply_multi_tensor_ema(pl_module) + return self.apply_ema(pl_module) + + def apply_multi_tensor_ema(self, pl_module: "pl.LightningModule") -> None: + model_weights = list(pl_module.state_dict().values()) + amp_C.multi_tensor_axpby( + 65536, + self._overflow_buf, + [self._ema_model_weights, model_weights, self._ema_model_weights], + self.decay, + 1 - self.decay, + -1, + ) + + def apply_ema(self, pl_module: "pl.LightningModule") -> None: + for orig_weight, ema_weight in zip( + list(pl_module.state_dict().values()), self._ema_model_weights + ): + if ( + ema_weight.data.dtype != torch.long + and orig_weight.data.dtype != torch.long + ): + # ensure that non-trainable parameters (e.g., feature distributions) are not included in EMA weight averaging + diff = ema_weight.data - orig_weight.data + diff.mul_(1.0 - self.decay) + ema_weight.sub_(diff) + + def should_apply_ema(self, step: int) -> bool: + return ( + step != self._cur_step + and step >= self.start_step + and step % self.apply_ema_every_n_steps == 0 + ) + + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + if self.should_apply_ema(trainer.global_step): + self._cur_step = trainer.global_step + self.ema(pl_module) + + def state_dict(self) -> Dict[str, Any]: + if self.save_ema_weights_in_callback_state: + return dict(cur_step=self._cur_step, ema_weights=self._ema_model_weights) + return dict(cur_step=self._cur_step) + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._cur_step = state_dict["cur_step"] + # when loading within apps such as NeMo, EMA weights will be loaded by the experiment manager separately + if self._ema_model_weights is None: + self._ema_model_weights = state_dict.get("ema_weights") + + def on_load_checkpoint( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + checkpoint: Dict[str, Any], + ) -> None: + checkpoint_callback = trainer.checkpoint_callback + + if trainer.ckpt_path and checkpoint_callback is not None: + ext = checkpoint_callback.FILE_EXTENSION + if trainer.ckpt_path.endswith(f"-EMA{ext}"): + rank_zero_info( + "loading EMA based weights. " + "The callback will treat the loaded EMA weights as the main weights" + " and create a new EMA copy when training." + ) + return + ema_path = trainer.ckpt_path.replace(ext, f"-EMA{ext}") + if os.path.exists(ema_path): + ema_state_dict = torch.load(ema_path, map_location=torch.device("cpu")) + self._ema_model_weights = ema_state_dict["state_dict"].values() + del ema_state_dict + rank_zero_info( + "EMA weights have been loaded successfully. Continuing training with saved EMA weights." + ) + else: + warnings.warn( + "we were unable to find the associated EMA weights when re-loading, " + "training will start with new EMA weights.", + UserWarning, + ) + + def replace_model_weights(self, pl_module: "pl.LightningModule") -> None: + self._weights_buffer = [ + p.detach().clone().to("cpu") for p in pl_module.state_dict().values() + ] + new_state_dict = { + k: v for k, v in zip(pl_module.state_dict().keys(), self._ema_model_weights) + } + pl_module.load_state_dict(new_state_dict) + + def restore_original_weights(self, pl_module: "pl.LightningModule") -> None: + state_dict = pl_module.state_dict() + new_state_dict = {k: v for k, v in zip(state_dict.keys(), self._weights_buffer)} + pl_module.load_state_dict(new_state_dict) + del self._weights_buffer + + @property + def ema_initialized(self) -> bool: + return self._ema_model_weights is not None + + def on_validation_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self.ema_initialized and self.evaluate_ema_weights_instead: + self.replace_model_weights(pl_module) + + def on_validation_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self.ema_initialized and self.evaluate_ema_weights_instead: + self.restore_original_weights(pl_module) + + def on_test_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self.ema_initialized and self.evaluate_ema_weights_instead: + self.replace_model_weights(pl_module) + + def on_test_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self.ema_initialized and self.evaluate_ema_weights_instead: + self.restore_original_weights(pl_module) + + +class EMAModelCheckpoint(ModelCheckpoint): + """ + Light wrapper around Lightning's `ModelCheckpoint` to, upon request, save an EMA copy of the model as well. + + Adapted from: https://github.com/NVIDIA/NeMo/blob/be0804f61e82dd0f63da7f9fe8a4d8388e330b18/nemo/utils/exp_manager.py#L744 + """ + + def __init__(self, **kwargs): + # call the parent class constructor with the provided kwargs + super().__init__(**kwargs) + + def _get_ema_callback(self, trainer: "pl.Trainer") -> Optional[EMA]: + ema_callback = None + for callback in trainer.callbacks: + if isinstance(callback, EMA): + ema_callback = callback + return ema_callback + + def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + super()._save_checkpoint(trainer, filepath) + ema_callback = self._get_ema_callback(trainer) + if ema_callback is not None: + # save EMA copy of the model as well + ema_callback.replace_model_weights(trainer.lightning_module) + filepath = self._ema_format_filepath(filepath) + if self.verbose: + rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") + os.makedirs(os.path.dirname(filepath), exist_ok=True) + super()._save_checkpoint(trainer, filepath) + ema_callback.restore_original_weights(trainer.lightning_module) + + def _ema_format_filepath(self, filepath: str) -> str: + return filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}") + + # only change the last line + def _update_best_and_save( + self, + current: Tensor, + trainer: "pl.Trainer", + monitor_candidates: Dict[str, Tensor], + ) -> None: + k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k + + del_filepath = None + if len(self.best_k_models) == k and k > 0: + del_filepath = self.kth_best_model_path + self.best_k_models.pop(del_filepath) + + # do not save nan, replace with +/- inf + if isinstance(current, Tensor) and torch.isnan(current): + current = torch.tensor( + float("inf" if self.mode == "min" else "-inf"), device=current.device + ) + + filepath = self._get_metric_interpolated_filepath_name( + monitor_candidates, trainer, del_filepath + ) + + # save the current score + self.current_score = current + self.best_k_models[filepath] = current + + if len(self.best_k_models) == k: + # monitor dict has reached k elements + _op = max if self.mode == "min" else min + self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type] + self.kth_value = self.best_k_models[self.kth_best_model_path] + + _op = min if self.mode == "min" else max + self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type] + self.best_model_score = self.best_k_models[self.best_model_path] + + if self.verbose: + epoch = monitor_candidates["epoch"] + step = monitor_candidates["step"] + rank_zero_info( + f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}" + f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}" + ) + self._save_checkpoint(trainer, filepath) + + if del_filepath is not None and filepath != del_filepath: + self._remove_checkpoint(trainer, del_filepath) + self._remove_checkpoint( + trainer, + del_filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}"), + ) diff --git a/step1x3d_geometry/utils/misc.py b/step1x3d_geometry/utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..1b63a1c52eb1dbafa618cf1c5c28a4a5de133ea1 --- /dev/null +++ b/step1x3d_geometry/utils/misc.py @@ -0,0 +1,165 @@ +import gc +import os +import re + +import torch +import torch.distributed as dist +from packaging import version + +from step1x3d_geometry.utils.config import config_to_primitive +from step1x3d_geometry.utils.typing import * + + +def parse_version(ver: str): + return version.parse(ver) + + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + + +def get_world_size(): + world_size_keys = ("WORLD_SIZE", "SLURM_NTASKS", "JSM_NAMESPACE_SIZE") + for key in world_size_keys: + world_size = os.environ.get(key) + if world_size is not None: + return int(world_size) + return 1 + + +def get_device(): + return torch.device(f"cuda:{get_rank()}") + + +def load_module_weights( + path, module_name=None, ignore_modules=None, map_location=None +) -> Tuple[dict, int, int]: + if module_name is not None and ignore_modules is not None: + raise ValueError("module_name and ignore_modules cannot be both set") + if map_location is None: + map_location = get_device() + + ckpt = torch.load(path, map_location=map_location) + state_dict = ckpt["state_dict"] + state_dict_to_load = state_dict + + if ignore_modules is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + ignore = any( + [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] + ) + if ignore: + continue + state_dict_to_load[k] = v + + if module_name is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + m = re.match(rf"^{module_name}\.(.*)$", k) + if m is None: + continue + state_dict_to_load[m.group(1)] = v + + return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] + + +def C(value: Any, epoch: int, global_step: int) -> float: + if isinstance(value, int) or isinstance(value, float): + pass + else: + value = config_to_primitive(value) + if not isinstance(value, list): + raise TypeError("Scalar specification only supports list, got", type(value)) + if len(value) == 3: + value = [0] + value + assert len(value) == 4 + start_step, start_value, end_value, end_step = value + if isinstance(end_step, int): + current_step = global_step + value = start_value + (end_value - start_value) * max( + min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 + ) + elif isinstance(end_step, float): + current_step = epoch + value = start_value + (end_value - start_value) * max( + min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 + ) + return value + + +def cleanup(): + gc.collect() + torch.cuda.empty_cache() + tcnn.free_temporary_memory() + + +def finish_with_cleanup(func: Callable): + def wrapper(*args, **kwargs): + out = func(*args, **kwargs) + cleanup() + return out + + return wrapper + + +def _distributed_available(): + return torch.distributed.is_available() and torch.distributed.is_initialized() + + +def barrier(): + if not _distributed_available(): + return + else: + torch.distributed.barrier() + + +def broadcast(tensor, src=0): + if not _distributed_available(): + return tensor + else: + torch.distributed.broadcast(tensor, src=src) + return tensor + + +def enable_gradient(model, enabled: bool = True) -> None: + for param in model.parameters(): + param.requires_grad_(enabled) + + +def all_gather_batch(tensors): + """ + Performs all_gather operation on the provided tensors. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + if isinstance(tensors, list): + return tensors + return tensors + if not isinstance(tensors, list): + is_list = False + tensors = [tensors] + else: + is_list = True + output_tensor = [] + tensor_list = [] + for tensor in tensors: + tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_all, tensor, async_op=False) # performance opt + + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + if not is_list: + return output_tensor[0] + return output_tensor diff --git a/step1x3d_geometry/utils/ops.py b/step1x3d_geometry/utils/ops.py new file mode 100755 index 0000000000000000000000000000000000000000..e12acb6e51b4711be97ae35784cea02d53846517 --- /dev/null +++ b/step1x3d_geometry/utils/ops.py @@ -0,0 +1,180 @@ +import math +from collections import defaultdict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import step1x3d_geometry +from step1x3d_geometry.utils.typing import * + + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + + +def reflect(x, n): + return 2 * dot(x, n) * n - x + + +ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] + + +def scale_tensor( + dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale +): + if inp_scale is None: + inp_scale = (0, 1) + if tgt_scale is None: + tgt_scale = (0, 1) + if isinstance(tgt_scale, Tensor): + assert dat.shape[-1] == tgt_scale.shape[-1] + dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) + dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] + return dat + + +def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any: + if chunk_size <= 0: + return func(*args, **kwargs) + B = None + for arg in list(args) + list(kwargs.values()): + if isinstance(arg, torch.Tensor): + B = arg.shape[0] + break + assert ( + B is not None + ), "No tensor found in args or kwargs, cannot determine batch size." + out = defaultdict(list) + out_type = None + # max(1, B) to support B == 0 + for i in range(0, max(1, B), chunk_size): + out_chunk = func( + *[ + arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for arg in args + ], + **{ + k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for k, arg in kwargs.items() + }, + ) + if out_chunk is None: + continue + out_type = type(out_chunk) + if isinstance(out_chunk, torch.Tensor): + out_chunk = {0: out_chunk} + elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): + chunk_length = len(out_chunk) + out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} + elif isinstance(out_chunk, dict): + pass + else: + print( + f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}." + ) + exit(1) + for k, v in out_chunk.items(): + v = v if torch.is_grad_enabled() else v.detach() + out[k].append(v) + + if out_type is None: + return None + + out_merged: Dict[Any, Optional[torch.Tensor]] = {} + for k, v in out.items(): + if all([vv is None for vv in v]): + # allow None in return value + out_merged[k] = None + elif all([isinstance(vv, torch.Tensor) for vv in v]): + out_merged[k] = torch.cat(v, dim=0) + else: + raise TypeError( + f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}" + ) + + if out_type is torch.Tensor: + return out_merged[0] + elif out_type in [tuple, list]: + return out_type([out_merged[i] for i in range(chunk_length)]) + elif out_type is dict: + return out_merged + + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional["torch.device"] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = ( + generator.device.type + if not isinstance(generator, list) + else generator[0].device.type + ) + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + print( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError( + f"Cannot generate a {device} tensor from a generator of type {gen_device_type}." + ) + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn( + shape, + generator=generator[i], + device=rand_device, + dtype=dtype, + layout=layout, + ) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn( + shape, generator=generator, device=rand_device, dtype=dtype, layout=layout + ).to(device) + + return latents + + +def generate_dense_grid_points( + bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, indexing: str = "ij" +): + length = bbox_max - bbox_min + num_cells = np.exp2(octree_depth) + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + xyz = xyz.reshape(-1, 3) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length diff --git a/step1x3d_geometry/utils/saving.py b/step1x3d_geometry/utils/saving.py new file mode 100755 index 0000000000000000000000000000000000000000..0a410bc98e7cdd388d14c3791f2fe86fb6ae9c0d --- /dev/null +++ b/step1x3d_geometry/utils/saving.py @@ -0,0 +1,463 @@ +import json +import os +import re +import shutil + +import cv2 +import imageio +import matplotlib.pyplot as plt +import numpy as np +import torch +import torchvision.utils as vutils +import trimesh +import wandb +from matplotlib import cm +from matplotlib.colors import LinearSegmentedColormap +from PIL import Image, ImageDraw +from pytorch_lightning.loggers import WandbLogger + +from step1x3d_geometry.utils.typing import * + + +class SaverMixin: + _save_dir: Optional[str] = None + _wandb_logger: Optional[WandbLogger] = None + + def set_save_dir(self, save_dir: str): + self._save_dir = save_dir + + def get_save_dir(self): + if self._save_dir is None: + raise ValueError("Save dir is not set") + return self._save_dir + + def convert_data(self, data): + if data is None: + return None + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + return data.detach().cpu().numpy() + elif isinstance(data, list): + return [self.convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: self.convert_data(v) for k, v in data.items()} + else: + raise TypeError( + "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", + type(data), + ) + + def get_save_path(self, filename): + save_path = os.path.join(self.get_save_dir(), filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + return save_path + + def create_loggers(self, cfg_loggers: DictConfig) -> None: + if "wandb" in cfg_loggers.keys() and cfg_loggers.wandb.enable: + self._wandb_logger = WandbLogger( + project=cfg_loggers.wandb.project, name=cfg_loggers.wandb.name + ) + + def get_loggers(self) -> List: + if self._wandb_logger: + return [self._wandb_logger] + else: + return [] + + DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)} + DEFAULT_UV_KWARGS = { + "data_format": "HWC", + "data_range": (0, 1), + "cmap": "checkerboard", + } + DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"} + DEFAULT_GRID_KWARGS = {"align": "max"} + + def get_rgb_image_(self, img, data_format, data_range, rgba=False): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + if img.dtype != np.uint8: + img = img.clip(min=data_range[0], max=data_range[1]) + img = ( + (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0 + ).astype(np.uint8) + nc = 4 if rgba else 3 + imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)] + imgs = [ + ( + img_ + if img_.shape[-1] == nc + else np.concatenate( + [ + img_, + np.zeros( + (img_.shape[0], img_.shape[1], nc - img_.shape[2]), + dtype=img_.dtype, + ), + ], + axis=-1, + ) + ) + for img_ in imgs + ] + img = np.concatenate(imgs, axis=1) + if rgba: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_rgb_image( + self, + filename, + img, + data_format, + data_range, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_rgb_image_(img, data_format, data_range) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + wandb.log( + { + name: wandb.Image(self.get_save_path(filename)), + "trainer/global_step": step, + } + ) + + def save_rgb_image( + self, + filename, + img, + data_format=DEFAULT_RGB_KWARGS["data_format"], + data_range=DEFAULT_RGB_KWARGS["data_range"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_rgb_image(save_path, img, data_format, data_range, name, step) + return save_path + + def get_uv_image_(self, img, data_format, data_range, cmap): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], max=data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in ["checkerboard", "color"] + if cmap == "checkerboard": + n_grid = 64 + mask = (img * n_grid).astype(int) + mask = (mask[..., 0] + mask[..., 1]) % 2 == 0 + img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 + img[mask] = np.array([255, 0, 255], dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif cmap == "color": + img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) + img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) + img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) + img = img_ + return img + + def save_uv_image( + self, + filename, + img, + data_format=DEFAULT_UV_KWARGS["data_format"], + data_range=DEFAULT_UV_KWARGS["data_range"], + cmap=DEFAULT_UV_KWARGS["cmap"], + ) -> str: + save_path = self.get_save_path(filename) + img = self.get_uv_image_(img, data_format, data_range, cmap) + cv2.imwrite(save_path, img) + return save_path + + def get_grayscale_image_(self, img, data_range, cmap): + img = self.convert_data(img) + img = np.nan_to_num(img) + if data_range is None: + img = (img - img.min()) / (img.max() - img.min()) + else: + img = img.clip(data_range[0], data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in [None, "jet", "magma", "spectral"] + if cmap == None: + img = (img * 255.0).astype(np.uint8) + img = np.repeat(img[..., None], 3, axis=2) + elif cmap == "jet": + img = (img * 255.0).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + elif cmap == "magma": + img = 1.0 - img + base = cm.get_cmap("magma") + num_bins = 256 + colormap = LinearSegmentedColormap.from_list( + f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins + )(np.linspace(0, 1, num_bins))[:, :3] + a = np.floor(img * 255.0) + b = (a + 1).clip(max=255.0) + f = img * 255.0 - a + a = a.astype(np.uint16).clip(0, 255) + b = b.astype(np.uint16).clip(0, 255) + img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None] + img = (img * 255.0).astype(np.uint8) + elif cmap == "spectral": + colormap = plt.get_cmap("Spectral") + + def blend_rgba(image): + image = image[..., :3] * image[..., -1:] + ( + 1.0 - image[..., -1:] + ) # blend A to RGB + return image + + img = colormap(img) + img = blend_rgba(img) + img = (img * 255).astype(np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_grayscale_image( + self, + filename, + img, + data_range, + cmap, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_grayscale_image_(img, data_range, cmap) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + wandb.log( + { + name: wandb.Image(self.get_save_path(filename)), + "trainer/global_step": step, + } + ) + + def save_grayscale_image( + self, + filename, + img, + data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"], + cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_grayscale_image(save_path, img, data_range, cmap, name, step) + return save_path + + def get_image_grid_(self, imgs, align): + if isinstance(imgs[0], list): + return np.concatenate( + [self.get_image_grid_(row, align) for row in imgs], axis=0 + ) + cols = [] + for col in imgs: + assert col["type"] in ["rgb", "uv", "grayscale"] + if col["type"] == "rgb": + rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() + rgb_kwargs.update(col["kwargs"]) + cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs)) + elif col["type"] == "uv": + uv_kwargs = self.DEFAULT_UV_KWARGS.copy() + uv_kwargs.update(col["kwargs"]) + cols.append(self.get_uv_image_(col["img"], **uv_kwargs)) + elif col["type"] == "grayscale": + grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() + grayscale_kwargs.update(col["kwargs"]) + cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs)) + + if align == "max": + h = max([col.shape[0] for col in cols]) + w = max([col.shape[1] for col in cols]) + elif align == "min": + h = min([col.shape[0] for col in cols]) + w = min([col.shape[1] for col in cols]) + elif isinstance(align, int): + h = align + w = align + elif ( + isinstance(align, tuple) + and isinstance(align[0], int) + and isinstance(align[1], int) + ): + h, w = align + else: + raise ValueError( + f"Unsupported image grid align: {align}, should be min, max, int or (int, int)" + ) + + for i in range(len(cols)): + if cols[i].shape[0] != h or cols[i].shape[1] != w: + cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_LINEAR) + return np.concatenate(cols, axis=1) + + def save_image_grid( + self, + filename, + imgs, + align=DEFAULT_GRID_KWARGS["align"], + name: Optional[str] = None, + step: Optional[int] = None, + texts: Optional[List[float]] = None, + ): + save_path = self.get_save_path(filename) + img = self.get_image_grid_(imgs, align=align) + + if texts is not None: + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + black, white = (0, 0, 0), (255, 255, 255) + for i, text in enumerate(texts): + draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black) + img = np.asarray(img) + + cv2.imwrite(save_path, img) + if name and self._wandb_logger: + wandb.log({name: wandb.Image(save_path), "trainer/global_step": step}) + return save_path + + def save_image(self, filename, img) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.dtype == np.uint8 or img.dtype == np.uint16 + if img.ndim == 3 and img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif img.ndim == 3 and img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + cv2.imwrite(save_path, img) + return save_path + + def save_image_vutils(self, filename, img) -> str: + save_path = self.get_save_path(filename) + vutils.save_image(img, save_path) + return save_path + + def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] + + imgs_full = [] + for start in range(0, img.shape[-1], 3): + img_ = img[..., start : start + 3] + img_ = np.stack( + [ + self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba) + for i in range(img_.shape[0]) + ], + axis=0, + ) + size = img_.shape[1] + placeholder = np.zeros((size, size, 3), dtype=np.float32) + img_full = np.concatenate( + [ + np.concatenate( + [placeholder, img_[2], placeholder, placeholder], axis=1 + ), + np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), + np.concatenate( + [placeholder, img_[3], placeholder, placeholder], axis=1 + ), + ], + axis=0, + ) + imgs_full.append(img_full) + + imgs_full = np.concatenate(imgs_full, axis=1) + cv2.imwrite(save_path, imgs_full) + return save_path + + def save_data(self, filename, data) -> str: + data = self.convert_data(data) + if isinstance(data, dict): + if not filename.endswith(".npz"): + filename += ".npz" + save_path = self.get_save_path(filename) + np.savez(save_path, **data) + else: + if not filename.endswith(".npy"): + filename += ".npy" + save_path = self.get_save_path(filename) + np.save(save_path, data) + return save_path + + def save_state_dict(self, filename, data) -> str: + save_path = self.get_save_path(filename) + torch.save(data, save_path) + return save_path + + def save_img_sequence( + self, + filename, + img_dir, + matcher, + save_format="mp4", + fps=30, + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + assert save_format in ["gif", "mp4"] + if not filename.endswith(save_format): + filename += f".{save_format}" + save_path = self.get_save_path(filename) + matcher = re.compile(matcher) + img_dir = os.path.join(self.get_save_dir(), img_dir) + imgs = [] + for f in os.listdir(img_dir): + if matcher.search(f): + imgs.append(f) + imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) + imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] + + if save_format == "gif": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps, palettesize=256) + elif save_format == "mp4": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps) + if name and self._wandb_logger: + wandb.log( + { + name: wandb.Video(save_path, format="mp4"), + "trainer/global_step": step, + } + ) + return save_path + + def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None) -> str: + save_path = self.get_save_path(filename) + v_pos = self.convert_data(v_pos) + t_pos_idx = self.convert_data(t_pos_idx) + mesh = trimesh.Trimesh(vertices=v_pos, faces=t_pos_idx) + mesh.export(save_path) + return save_path + + def save_file(self, filename, src_path) -> str: + save_path = self.get_save_path(filename) + shutil.copyfile(src_path, save_path) + return save_path + + def save_txt(self, filename, comment) -> str: + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(comment) + return save_path + + def save_json(self, filename, payload) -> str: + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(json.dumps(payload)) + return save_path diff --git a/step1x3d_geometry/utils/scheduler.py b/step1x3d_geometry/utils/scheduler.py new file mode 100755 index 0000000000000000000000000000000000000000..ef03eac2179a3d5026bded3b4c6c84c82cb06050 --- /dev/null +++ b/step1x3d_geometry/utils/scheduler.py @@ -0,0 +1,108 @@ +import sys +import warnings +from bisect import bisect_right + +import torch +import torch.nn as nn +from torch.optim import lr_scheduler + +import step1x3d_geometry + + +def get_scheduler(name): + if hasattr(lr_scheduler, name): + return getattr(lr_scheduler, name) + else: + raise NotImplementedError + + +def getattr_recursive(m, attr): + for name in attr.split("."): + m = getattr(m, name) + return m + + +def get_parameters(model, name): + module = getattr_recursive(model, name) + if isinstance(module, nn.Module): + return module.parameters() + elif isinstance(module, nn.Parameter): + return module + return [] + + +def parse_optimizer(config, model): + if hasattr(config, "params"): + params = [ + {"params": get_parameters(model, name), "name": name, **args} + for name, args in config.params.items() + ] + step1x3d_geometry.debug(f"Specify optimizer params: {config.params}") + else: + if hasattr(config, "only_requires_grad") and config.only_requires_grad: + params = list(filter(lambda p: p.requires_grad, model.parameters())) + else: + params = model.parameters() + + if config.name in ["FusedAdam"]: + import apex + + optim = getattr(apex.optimizers, config.name)(params, **config.args) + elif config.name in ["Prodigy"]: + import prodigyopt + + optim = getattr(prodigyopt, config.name)(params, **config.args) + else: + optim = getattr(torch.optim, config.name)(params, **config.args) + return optim + + +def parse_scheduler_to_instance(config, optimizer): + if config.name == "ChainedScheduler": + schedulers = [ + parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers + ] + scheduler = lr_scheduler.ChainedScheduler(schedulers) + elif config.name == "Sequential": + schedulers = [ + parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers + ] + scheduler = lr_scheduler.SequentialLR( + optimizer, schedulers, milestones=config.milestones + ) + else: + scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args) + return scheduler + + +def parse_scheduler(config, optimizer): + interval = config.get("interval", "epoch") + assert interval in ["epoch", "step"] + if config.name == "SequentialLR": + scheduler = { + "scheduler": lr_scheduler.SequentialLR( + optimizer, + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ], + milestones=config.milestones, + ), + "interval": interval, + } + elif config.name == "ChainedScheduler": + scheduler = { + "scheduler": lr_scheduler.ChainedScheduler( + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ] + ), + "interval": interval, + } + else: + scheduler = { + "scheduler": get_scheduler(config.name)(optimizer, **config.args), + "interval": interval, + } + return scheduler diff --git a/step1x3d_geometry/utils/typing.py b/step1x3d_geometry/utils/typing.py new file mode 100755 index 0000000000000000000000000000000000000000..21b1bb2dda7aac2ccc9f26fb47242892b470e671 --- /dev/null +++ b/step1x3d_geometry/utils/typing.py @@ -0,0 +1,41 @@ +""" +This module contains type annotations for the project, using +1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects +2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors + +Two types of typing checking can be used: +1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) +2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) +""" + +# Basic types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + NamedTuple, + NewType, + Optional, + Sized, + Tuple, + Type, + TypeVar, + Union, + Sequence, +) + +# Tensor dtype +# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md +from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt + +# Config type +from omegaconf import DictConfig + +# PyTorch Tensor type +from torch import Tensor + +# Runtime type checking decorator +from typeguard import typechecked as typechecker diff --git a/step1x3d_texture/custom_rasterizer/custom_rasterizer/__init__.py b/step1x3d_texture/custom_rasterizer/custom_rasterizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76b85df7af048a36b861910a04ff63356da28a6a --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/custom_rasterizer/__init__.py @@ -0,0 +1,22 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +""" +from .hierarchy import BuildHierarchy, BuildHierarchyWithColor +from .io_obj import LoadObj, LoadObjWithTexture +from .render import rasterize, interpolate +""" +from .io_glb import * +from .io_obj import * +from .render import * diff --git a/step1x3d_texture/custom_rasterizer/custom_rasterizer/io_glb.py b/step1x3d_texture/custom_rasterizer/custom_rasterizer/io_glb.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ce7909ae14237ac070f7a99de7b2bcba64739f --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/custom_rasterizer/io_glb.py @@ -0,0 +1,276 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import base64 +import io +import os + +import numpy as np +from PIL import Image as PILImage +from pygltflib import GLTF2 +from scipy.spatial.transform import Rotation as R + + +# Function to extract buffer data +def get_buffer_data(gltf, buffer_view): + buffer = gltf.buffers[buffer_view.buffer] + buffer_data = gltf.get_data_from_buffer_uri(buffer.uri) + byte_offset = buffer_view.byteOffset if buffer_view.byteOffset else 0 + byte_length = buffer_view.byteLength + return buffer_data[byte_offset : byte_offset + byte_length] + + +# Function to extract attribute data +def get_attribute_data(gltf, accessor_index): + accessor = gltf.accessors[accessor_index] + buffer_view = gltf.bufferViews[accessor.bufferView] + buffer_data = get_buffer_data(gltf, buffer_view) + + comptype = { + 5120: np.int8, + 5121: np.uint8, + 5122: np.int16, + 5123: np.uint16, + 5125: np.uint32, + 5126: np.float32, + } + dtype = comptype[accessor.componentType] + + t2n = { + "SCALAR": 1, + "VEC2": 2, + "VEC3": 3, + "VEC4": 4, + "MAT2": 4, + "MAT3": 9, + "MAT4": 16, + } + num_components = t2n[accessor.type] + + # Calculate the correct slice of data + byte_offset = accessor.byteOffset if accessor.byteOffset else 0 + byte_stride = ( + buffer_view.byteStride + if buffer_view.byteStride + else num_components * np.dtype(dtype).itemsize + ) + count = accessor.count + + # Extract the attribute data + attribute_data = np.zeros((count, num_components), dtype=dtype) + for i in range(count): + start = byte_offset + i * byte_stride + end = start + num_components * np.dtype(dtype).itemsize + attribute_data[i] = np.frombuffer(buffer_data[start:end], dtype=dtype) + + return attribute_data + + +# Function to extract image data +def get_image_data(gltf, image, folder): + if image.uri: + if image.uri.startswith("data:"): + # Data URI + header, encoded = image.uri.split(",", 1) + data = base64.b64decode(encoded) + else: + # External file + fn = image.uri + if not os.path.isabs(fn): + fn = folder + "/" + fn + with open(fn, "rb") as f: + data = f.read() + else: + buffer_view = gltf.bufferViews[image.bufferView] + data = get_buffer_data(gltf, buffer_view) + return data + + +# Function to convert triangle strip to triangles +def convert_triangle_strip_to_triangles(indices): + triangles = [] + for i in range(len(indices) - 2): + if i % 2 == 0: + triangles.append([indices[i], indices[i + 1], indices[i + 2]]) + else: + triangles.append([indices[i], indices[i + 2], indices[i + 1]]) + return np.array(triangles).reshape(-1, 3) + + +# Function to convert triangle fan to triangles +def convert_triangle_fan_to_triangles(indices): + triangles = [] + for i in range(1, len(indices) - 1): + triangles.append([indices[0], indices[i], indices[i + 1]]) + return np.array(triangles).reshape(-1, 3) + + +# Function to get the transformation matrix from a node +def get_node_transform(node): + if node.matrix: + return np.array(node.matrix).reshape(4, 4).T + else: + T = np.eye(4) + if node.translation: + T[:3, 3] = node.translation + if node.rotation: + R_mat = R.from_quat(node.rotation).as_matrix() + T[:3, :3] = R_mat + if node.scale: + S = np.diag(node.scale + [1]) + T = T @ S + return T + + +def get_world_transform(gltf, node_index, parents, world_transforms): + if parents[node_index] == -2: + return world_transforms[node_index] + + node = gltf.nodes[node_index] + if parents[node_index] == -1: + world_transforms[node_index] = get_node_transform(node) + parents[node_index] = -2 + return world_transforms[node_index] + + parent_index = parents[node_index] + parent_transform = get_world_transform( + gltf, parent_index, parents, world_transforms + ) + world_transforms[node_index] = parent_transform @ get_node_transform(node) + parents[node_index] = -2 + return world_transforms[node_index] + + +def LoadGlb(path): + # Load the GLB file using pygltflib + gltf = GLTF2().load(path) + + primitives = [] + images = {} + # Iterate through the meshes in the GLB file + + world_transforms = [np.identity(4) for i in range(len(gltf.nodes))] + parents = [-1 for i in range(len(gltf.nodes))] + for node_index, node in enumerate(gltf.nodes): + for idx in node.children: + parents[idx] = node_index + # for i in range(len(gltf.nodes)): + # get_world_transform(gltf, i, parents, world_transform) + + for node_index, node in enumerate(gltf.nodes): + if node.mesh is not None: + world_transform = get_world_transform( + gltf, node_index, parents, world_transforms + ) + # Iterate through the primitives in the mesh + mesh = gltf.meshes[node.mesh] + for primitive in mesh.primitives: + # Access the attributes of the primitive + attributes = primitive.attributes.__dict__ + mode = ( + primitive.mode if primitive.mode is not None else 4 + ) # Default to TRIANGLES + result = {} + if primitive.indices is not None: + indices = get_attribute_data(gltf, primitive.indices) + if mode == 4: # TRIANGLES + face_indices = indices.reshape(-1, 3) + elif mode == 5: # TRIANGLE_STRIP + face_indices = convert_triangle_strip_to_triangles(indices) + elif mode == 6: # TRIANGLE_FAN + face_indices = convert_triangle_fan_to_triangles(indices) + else: + continue + result["F"] = face_indices + + # Extract vertex positions + if "POSITION" in attributes and attributes["POSITION"] is not None: + positions = get_attribute_data(gltf, attributes["POSITION"]) + # Apply the world transformation to the positions + positions_homogeneous = np.hstack( + [positions, np.ones((positions.shape[0], 1))] + ) + transformed_positions = ( + world_transform @ positions_homogeneous.T + ).T[:, :3] + result["V"] = transformed_positions + + # Extract vertex colors + if "COLOR_0" in attributes and attributes["COLOR_0"] is not None: + colors = get_attribute_data(gltf, attributes["COLOR_0"]) + if colors.shape[-1] > 3: + colors = colors[..., :3] + result["VC"] = colors + + # Extract UVs + if "TEXCOORD_0" in attributes and not attributes["TEXCOORD_0"] is None: + uvs = get_attribute_data(gltf, attributes["TEXCOORD_0"]) + result["UV"] = uvs + + if primitive.material is not None: + material = gltf.materials[primitive.material] + if ( + material.pbrMetallicRoughness is not None + and material.pbrMetallicRoughness.baseColorTexture is not None + ): + texture_index = ( + material.pbrMetallicRoughness.baseColorTexture.index + ) + texture = gltf.textures[texture_index] + image_index = texture.source + if not image_index in images: + image = gltf.images[image_index] + image_data = get_image_data( + gltf, image, os.path.dirname(path) + ) + pil_image = PILImage.open(io.BytesIO(image_data)) + if pil_image.mode != "RGB": + pil_image = pil_image.convert("RGB") + images[image_index] = pil_image + result["TEX"] = image_index + elif material.emissiveTexture is not None: + texture_index = material.emissiveTexture.index + texture = gltf.textures[texture_index] + image_index = texture.source + if not image_index in images: + image = gltf.images[image_index] + image_data = get_image_data( + gltf, image, os.path.dirname(path) + ) + pil_image = PILImage.open(io.BytesIO(image_data)) + if pil_image.mode != "RGB": + pil_image = pil_image.convert("RGB") + images[image_index] = pil_image + result["TEX"] = image_index + else: + if material.pbrMetallicRoughness is not None: + base_color = material.pbrMetallicRoughness.baseColorFactor + else: + base_color = np.array([0.8, 0.8, 0.8], dtype=np.float32) + result["MC"] = base_color + + primitives.append(result) + + return primitives, images + + +def RotatePrimitives(primitives, transform): + for i in range(len(primitives)): + if "V" in primitives[i]: + primitives[i]["V"] = primitives[i]["V"] @ transform.T + + +if __name__ == "__main__": + path = "data/test.glb" + LoadGlb(path) diff --git a/step1x3d_texture/custom_rasterizer/custom_rasterizer/io_obj.py b/step1x3d_texture/custom_rasterizer/custom_rasterizer/io_obj.py new file mode 100644 index 0000000000000000000000000000000000000000..0ab7aa2f60a45db69538681e72a0317f76afb587 --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/custom_rasterizer/io_obj.py @@ -0,0 +1,71 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import cv2 +import numpy as np + + +def LoadObj(fn): + lines = [l.strip() for l in open(fn)] + vertices = [] + faces = [] + for l in lines: + words = [w for w in l.split(" ") if w != ""] + if len(words) == 0: + continue + if words[0] == "v": + v = [float(words[i]) for i in range(1, 4)] + vertices.append(v) + elif words[0] == "f": + f = [int(words[i]) - 1 for i in range(1, 4)] + faces.append(f) + + return np.array(vertices).astype("float32"), np.array(faces).astype("int32") + + +def LoadObjWithTexture(fn, tex_fn): + lines = [l.strip() for l in open(fn)] + vertices = [] + vertex_textures = [] + faces = [] + face_textures = [] + for l in lines: + words = [w for w in l.split(" ") if w != ""] + if len(words) == 0: + continue + if words[0] == "v": + v = [float(words[i]) for i in range(1, len(words))] + vertices.append(v) + elif words[0] == "vt": + v = [float(words[i]) for i in range(1, len(words))] + vertex_textures.append(v) + elif words[0] == "f": + f = [] + ft = [] + for i in range(1, len(words)): + t = words[i].split("/") + f.append(int(t[0]) - 1) + ft.append(int(t[1]) - 1) + for i in range(2, len(f)): + faces.append([f[0], f[i - 1], f[i]]) + face_textures.append([ft[0], ft[i - 1], ft[i]]) + + tex_image = cv2.cvtColor(cv2.imread(tex_fn), cv2.COLOR_BGR2RGB) + return ( + np.array(vertices).astype("float32"), + np.array(vertex_textures).astype("float32"), + np.array(faces).astype("int32"), + np.array(face_textures).astype("int32"), + tex_image, + ) diff --git a/step1x3d_texture/custom_rasterizer/custom_rasterizer/render.py b/step1x3d_texture/custom_rasterizer/custom_rasterizer/render.py new file mode 100644 index 0000000000000000000000000000000000000000..9d06b5195ed18d3cb581ace9efc1c2cebed232eb --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/custom_rasterizer/render.py @@ -0,0 +1,32 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import custom_rasterizer_kernel +import torch + + +def rasterize(pos, tri, resolution, clamp_depth=torch.zeros(0), use_depth_prior=0): + assert pos.device == tri.device + findices, barycentric = custom_rasterizer_kernel.rasterize_image( + pos[0], tri, clamp_depth, resolution[1], resolution[0], 1e-6, use_depth_prior + ) + return findices, barycentric + + +def interpolate(col, findices, barycentric, tri): + f = findices - 1 + (findices == 0) + vcol = col[0, tri.long()[f.long()]] + result = barycentric.view(*barycentric.shape, 1) * vcol + result = torch.sum(result, axis=-2) + return result.view(1, *result.shape) diff --git a/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/__init__.py b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1614ff826f65b4720649d343c43fb09dbb6b9fa5 --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/__init__.py @@ -0,0 +1,13 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. diff --git a/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/grid_neighbor.cpp b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/grid_neighbor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65ab321f397ab8642146795af53b0b8d91374607 --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/grid_neighbor.cpp @@ -0,0 +1,574 @@ +#include "rasterizer.h" +#include + +inline int pos2key(float* p, int resolution) { + int x = (p[0] * 0.5 + 0.5) * resolution; + int y = (p[1] * 0.5 + 0.5) * resolution; + int z = (p[2] * 0.5 + 0.5) * resolution; + return (x * resolution + y) * resolution + z; +} + +inline void key2pos(int key, int resolution, float* p) { + int x = key / resolution / resolution; + int y = key / resolution % resolution; + int z = key % resolution; + p[0] = ((x + 0.5) / resolution - 0.5) * 2; + p[1] = ((y + 0.5) / resolution - 0.5) * 2; + p[2] = ((z + 0.5) / resolution - 0.5) * 2; +} + +inline void key2cornerpos(int key, int resolution, float* p) { + int x = key / resolution / resolution; + int y = key / resolution % resolution; + int z = key % resolution; + p[0] = ((x + 0.75) / resolution - 0.5) * 2; + p[1] = ((y + 0.25) / resolution - 0.5) * 2; + p[2] = ((z + 0.75) / resolution - 0.5) * 2; +} + +inline float* pos_ptr(int l, int i, int j, torch::Tensor t) { + float* pdata = t.data_ptr(); + int height = t.size(1); + int width = t.size(2); + return &pdata[((l * height + i) * width + j) * 4]; +} + +struct Grid +{ + std::vector seq2oddcorner; + std::vector seq2evencorner; + std::vector seq2grid; + std::vector seq2normal; + std::vector seq2neighbor; + std::unordered_map grid2seq; + std::vector downsample_seq; + int num_origin_seq; + int resolution; + int stride; +}; + +inline void pos_from_seq(Grid& grid, int seq, float* p) { + auto k = grid.seq2grid[seq]; + key2pos(k, grid.resolution, p); +} + +inline int fetch_seq(Grid& grid, int l, int i, int j, torch::Tensor pdata) { + float* p = pos_ptr(l, i, j, pdata); + if (p[3] == 0) + return -1; + auto key = pos2key(p, grid.resolution); + int seq = grid.grid2seq[key]; + return seq; +} + +inline int fetch_last_seq(Grid& grid, int i, int j, torch::Tensor pdata) { + int num_layers = pdata.size(0); + int l = 0; + int idx = fetch_seq(grid, l, i, j, pdata); + while (l < num_layers - 1) { + l += 1; + int new_idx = fetch_seq(grid, l, i, j, pdata); + if (new_idx == -1) + break; + idx = new_idx; + } + return idx; +} + +inline int fetch_nearest_seq(Grid& grid, int i, int j, int dim, float d, torch::Tensor pdata) { + float p[3]; + float max_dist = 1e10; + int best_idx = -1; + int num_layers = pdata.size(0); + for (int l = 0; l < num_layers; ++l) { + int idx = fetch_seq(grid, l, i, j, pdata); + if (idx == -1) + break; + pos_from_seq(grid, idx, p); + float dist = std::abs(d - p[(dim + 2) % 3]); + if (dist < max_dist) { + max_dist = dist; + best_idx = idx; + } + } + return best_idx; +} + +inline int fetch_nearest_seq_layer(Grid& grid, int i, int j, int dim, float d, torch::Tensor pdata) { + float p[3]; + float max_dist = 1e10; + int best_layer = -1; + int num_layers = pdata.size(0); + for (int l = 0; l < num_layers; ++l) { + int idx = fetch_seq(grid, l, i, j, pdata); + if (idx == -1) + break; + pos_from_seq(grid, idx, p); + float dist = std::abs(d - p[(dim + 2) % 3]); + if (dist < max_dist) { + max_dist = dist; + best_layer = l; + } + } + return best_layer; +} + +void FetchNeighbor(Grid& grid, int seq, float* pos, int dim, int boundary_info, std::vector& view_layer_positions, + int* output_indices) +{ + auto t = view_layer_positions[dim]; + int height = t.size(1); + int width = t.size(2); + int top = 0; + int ci = 0; + int cj = 0; + if (dim == 0) { + ci = (pos[1]/2+0.5)*height; + cj = (pos[0]/2+0.5)*width; + } + else if (dim == 1) { + ci = (pos[1]/2+0.5)*height; + cj = (pos[2]/2+0.5)*width; + } + else { + ci = (-pos[2]/2+0.5)*height; + cj = (pos[0]/2+0.5)*width; + } + int stride = grid.stride; + for (int ni = ci + stride; ni >= ci - stride; ni -= stride) { + for (int nj = cj - stride; nj <= cj + stride; nj += stride) { + int idx = -1; + if (ni == ci && nj == cj) + idx = seq; + else if (!(ni < 0 || ni >= height || nj < 0 || nj >= width)) { + if (boundary_info == -1) + idx = fetch_seq(grid, 0, ni, nj, t); + else if (boundary_info == 1) + idx = fetch_last_seq(grid, ni, nj, t); + else + idx = fetch_nearest_seq(grid, ni, nj, dim, pos[(dim + 2) % 3], t); + } + output_indices[top] = idx; + top += 1; + } + } +} + +void DownsampleGrid(Grid& src, Grid& tar) +{ + src.downsample_seq.resize(src.seq2grid.size(), -1); + tar.resolution = src.resolution / 2; + tar.stride = src.stride * 2; + float pos[3]; + std::vector seq2normal_count; + for (int i = 0; i < src.seq2grid.size(); ++i) { + key2pos(src.seq2grid[i], src.resolution, pos); + int k = pos2key(pos, tar.resolution); + int s = seq2normal_count.size(); + if (!tar.grid2seq.count(k)) { + tar.grid2seq[k] = tar.seq2grid.size(); + tar.seq2grid.emplace_back(k); + seq2normal_count.emplace_back(0); + seq2normal_count.emplace_back(0); + seq2normal_count.emplace_back(0); + //tar.seq2normal.emplace_back(src.seq2normal[i]); + } else { + s = tar.grid2seq[k] * 3; + } + seq2normal_count[s + src.seq2normal[i]] += 1; + src.downsample_seq[i] = tar.grid2seq[k]; + } + tar.seq2normal.resize(seq2normal_count.size() / 3); + for (int i = 0; i < seq2normal_count.size(); i += 3) { + int t = 0; + for (int j = 1; j < 3; ++j) { + if (seq2normal_count[i + j] > seq2normal_count[i + t]) + t = j; + } + tar.seq2normal[i / 3] = t; + } +} + +void NeighborGrid(Grid& grid, std::vector view_layer_positions, int v) +{ + grid.seq2evencorner.resize(grid.seq2grid.size(), 0); + grid.seq2oddcorner.resize(grid.seq2grid.size(), 0); + std::unordered_set visited_seq; + for (int vd = 0; vd < 3; ++vd) { + auto t = view_layer_positions[vd]; + auto t0 = view_layer_positions[v]; + int height = t.size(1); + int width = t.size(2); + int num_layers = t.size(0); + int num_view_layers = t0.size(0); + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + for (int l = 0; l < num_layers; ++l) { + int seq = fetch_seq(grid, l, i, j, t); + if (seq == -1) + break; + int dim = grid.seq2normal[seq]; + if (dim != v) + continue; + + float pos[3]; + pos_from_seq(grid, seq, pos); + + int ci = 0; + int cj = 0; + if (dim == 0) { + ci = (pos[1]/2+0.5)*height; + cj = (pos[0]/2+0.5)*width; + } + else if (dim == 1) { + ci = (pos[1]/2+0.5)*height; + cj = (pos[2]/2+0.5)*width; + } + else { + ci = (-pos[2]/2+0.5)*height; + cj = (pos[0]/2+0.5)*width; + } + + if ((ci % (grid.stride * 2) < grid.stride) && (cj % (grid.stride * 2) >= grid.stride)) + grid.seq2evencorner[seq] = 1; + + if ((ci % (grid.stride * 2) >= grid.stride) && (cj % (grid.stride * 2) < grid.stride)) + grid.seq2oddcorner[seq] = 1; + + bool is_boundary = false; + if (vd == v) { + if (l == 0 || l == num_layers - 1) + is_boundary = true; + else { + int seq_new = fetch_seq(grid, l + 1, i, j, t); + if (seq_new == -1) + is_boundary = true; + } + } + int boundary_info = 0; + if (is_boundary && (l == 0)) + boundary_info = -1; + else if (is_boundary) + boundary_info = 1; + if (visited_seq.count(seq)) + continue; + visited_seq.insert(seq); + + FetchNeighbor(grid, seq, pos, dim, boundary_info, view_layer_positions, &grid.seq2neighbor[seq * 9]); + } + } + } + } +} + +void PadGrid(Grid& src, Grid& tar, std::vector& view_layer_positions) { + auto& downsample_seq = src.downsample_seq; + auto& seq2evencorner = src.seq2evencorner; + auto& seq2oddcorner = src.seq2oddcorner; + int indices[9]; + std::vector mapped_even_corners(tar.seq2grid.size(), 0); + std::vector mapped_odd_corners(tar.seq2grid.size(), 0); + for (int i = 0; i < downsample_seq.size(); ++i) { + if (seq2evencorner[i] > 0) { + mapped_even_corners[downsample_seq[i]] = 1; + } + if (seq2oddcorner[i] > 0) { + mapped_odd_corners[downsample_seq[i]] = 1; + } + } + auto& tar_seq2normal = tar.seq2normal; + auto& tar_seq2grid = tar.seq2grid; + for (int i = 0; i < tar_seq2grid.size(); ++i) { + if (mapped_even_corners[i] == 1 && mapped_odd_corners[i] == 1) + continue; + auto k = tar_seq2grid[i]; + float p[3]; + key2cornerpos(k, tar.resolution, p); + + int src_key = pos2key(p, src.resolution); + if (!src.grid2seq.count(src_key)) { + int seq = src.seq2grid.size(); + src.grid2seq[src_key] = seq; + src.seq2evencorner.emplace_back((mapped_even_corners[i] == 0)); + src.seq2oddcorner.emplace_back((mapped_odd_corners[i] == 0)); + src.seq2grid.emplace_back(src_key); + src.seq2normal.emplace_back(tar_seq2normal[i]); + FetchNeighbor(src, seq, p, tar_seq2normal[i], 0, view_layer_positions, indices); + for (int j = 0; j < 9; ++j) { + src.seq2neighbor.emplace_back(indices[j]); + } + src.downsample_seq.emplace_back(i); + } else { + int seq = src.grid2seq[src_key]; + if (mapped_even_corners[i] == 0) + src.seq2evencorner[seq] = 1; + if (mapped_odd_corners[i] == 0) + src.seq2oddcorner[seq] = 1; + } + } +} + +std::vector> build_hierarchy(std::vector view_layer_positions, + std::vector view_layer_normals, int num_level, int resolution) +{ + if (view_layer_positions.size() != 3 || num_level < 1) { + printf("Alert! We require 3 layers and at least 1 level! (%d %d)\n", view_layer_positions.size(), num_level); + return {{},{},{},{}}; + } + + std::vector grids; + grids.resize(num_level); + + std::vector seq2pos; + auto& seq2grid = grids[0].seq2grid; + auto& seq2normal = grids[0].seq2normal; + auto& grid2seq = grids[0].grid2seq; + grids[0].resolution = resolution; + grids[0].stride = 1; + + auto int64_options = torch::TensorOptions().dtype(torch::kInt64).requires_grad(false); + auto float_options = torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + + for (int v = 0; v < 3; ++v) { + int num_layers = view_layer_positions[v].size(0); + int height = view_layer_positions[v].size(1); + int width = view_layer_positions[v].size(2); + float* data = view_layer_positions[v].data_ptr(); + float* data_normal = view_layer_normals[v].data_ptr(); + for (int l = 0; l < num_layers; ++l) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + float* p = &data[(i * width + j) * 4]; + float* n = &data_normal[(i * width + j) * 3]; + if (p[3] == 0) + continue; + auto k = pos2key(p, resolution); + if (!grid2seq.count(k)) { + int dim = 0; + for (int d = 0; d < 3; ++d) { + if (std::abs(n[d]) > std::abs(n[dim])) + dim = d; + } + dim = (dim + 1) % 3; + grid2seq[k] = seq2grid.size(); + seq2grid.emplace_back(k); + seq2pos.push_back(p[0]); + seq2pos.push_back(p[1]); + seq2pos.push_back(p[2]); + seq2normal.emplace_back(dim); + } + } + } + data += (height * width * 4); + data_normal += (height * width * 3); + } + } + + for (int i = 0; i < num_level - 1; ++i) { + DownsampleGrid(grids[i], grids[i + 1]); + } + + for (int l = 0; l < num_level; ++l) { + grids[l].seq2neighbor.resize(grids[l].seq2grid.size() * 9, -1); + grids[l].num_origin_seq = grids[l].seq2grid.size(); + for (int d = 0; d < 3; ++d) { + NeighborGrid(grids[l], view_layer_positions, d); + } + } + + for (int i = num_level - 2; i >= 0; --i) { + PadGrid(grids[i], grids[i + 1], view_layer_positions); + } + for (int i = grids[0].num_origin_seq; i < grids[0].seq2grid.size(); ++i) { + int k = grids[0].seq2grid[i]; + float p[3]; + key2pos(k, grids[0].resolution, p); + seq2pos.push_back(p[0]); + seq2pos.push_back(p[1]); + seq2pos.push_back(p[2]); + } + + std::vector texture_positions(2); + std::vector grid_neighbors(grids.size()); + std::vector grid_downsamples(grids.size() - 1); + std::vector grid_evencorners(grids.size()); + std::vector grid_oddcorners(grids.size()); + + texture_positions[0] = torch::zeros({static_cast(seq2pos.size() / 3), static_cast(3)}, float_options); + texture_positions[1] = torch::zeros({static_cast(seq2pos.size() / 3)}, float_options); + float* positions_out_ptr = texture_positions[0].data_ptr(); + memcpy(positions_out_ptr, seq2pos.data(), sizeof(float) * seq2pos.size()); + positions_out_ptr = texture_positions[1].data_ptr(); + for (int i = 0; i < grids[0].seq2grid.size(); ++i) { + positions_out_ptr[i] = (i < grids[0].num_origin_seq); + } + + for (int i = 0; i < grids.size(); ++i) { + grid_neighbors[i] = torch::zeros({static_cast(grids[i].seq2grid.size()), static_cast(9)}, int64_options); + int64_t* nptr = grid_neighbors[i].data_ptr(); + for (int j = 0; j < grids[i].seq2neighbor.size(); ++j) { + nptr[j] = grids[i].seq2neighbor[j]; + } + + grid_evencorners[i] = torch::zeros({static_cast(grids[i].seq2evencorner.size())}, int64_options); + grid_oddcorners[i] = torch::zeros({static_cast(grids[i].seq2oddcorner.size())}, int64_options); + int64_t* dptr = grid_evencorners[i].data_ptr(); + for (int j = 0; j < grids[i].seq2evencorner.size(); ++j) { + dptr[j] = grids[i].seq2evencorner[j]; + } + dptr = grid_oddcorners[i].data_ptr(); + for (int j = 0; j < grids[i].seq2oddcorner.size(); ++j) { + dptr[j] = grids[i].seq2oddcorner[j]; + } + if (i + 1 < grids.size()) { + grid_downsamples[i] = torch::zeros({static_cast(grids[i].downsample_seq.size())}, int64_options); + int64_t* dptr = grid_downsamples[i].data_ptr(); + for (int j = 0; j < grids[i].downsample_seq.size(); ++j) { + dptr[j] = grids[i].downsample_seq[j]; + } + } + + } + return {texture_positions, grid_neighbors, grid_downsamples, grid_evencorners, grid_oddcorners}; +} + +std::vector> build_hierarchy_with_feat( + std::vector view_layer_positions, + std::vector view_layer_normals, + std::vector view_layer_feats, + int num_level, int resolution) +{ + if (view_layer_positions.size() != 3 || num_level < 1) { + printf("Alert! We require 3 layers and at least 1 level! (%d %d)\n", view_layer_positions.size(), num_level); + return {{},{},{},{}}; + } + + std::vector grids; + grids.resize(num_level); + + std::vector seq2pos; + std::vector seq2feat; + auto& seq2grid = grids[0].seq2grid; + auto& seq2normal = grids[0].seq2normal; + auto& grid2seq = grids[0].grid2seq; + grids[0].resolution = resolution; + grids[0].stride = 1; + + auto int64_options = torch::TensorOptions().dtype(torch::kInt64).requires_grad(false); + auto float_options = torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + + int feat_channel = 3; + for (int v = 0; v < 3; ++v) { + int num_layers = view_layer_positions[v].size(0); + int height = view_layer_positions[v].size(1); + int width = view_layer_positions[v].size(2); + float* data = view_layer_positions[v].data_ptr(); + float* data_normal = view_layer_normals[v].data_ptr(); + float* data_feat = view_layer_feats[v].data_ptr(); + feat_channel = view_layer_feats[v].size(3); + for (int l = 0; l < num_layers; ++l) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + float* p = &data[(i * width + j) * 4]; + float* n = &data_normal[(i * width + j) * 3]; + float* f = &data_feat[(i * width + j) * feat_channel]; + if (p[3] == 0) + continue; + auto k = pos2key(p, resolution); + if (!grid2seq.count(k)) { + int dim = 0; + for (int d = 0; d < 3; ++d) { + if (std::abs(n[d]) > std::abs(n[dim])) + dim = d; + } + dim = (dim + 1) % 3; + grid2seq[k] = seq2grid.size(); + seq2grid.emplace_back(k); + seq2pos.push_back(p[0]); + seq2pos.push_back(p[1]); + seq2pos.push_back(p[2]); + for (int c = 0; c < feat_channel; ++c) { + seq2feat.emplace_back(f[c]); + } + seq2normal.emplace_back(dim); + } + } + } + data += (height * width * 4); + data_normal += (height * width * 3); + data_feat += (height * width * feat_channel); + } + } + + for (int i = 0; i < num_level - 1; ++i) { + DownsampleGrid(grids[i], grids[i + 1]); + } + + for (int l = 0; l < num_level; ++l) { + grids[l].seq2neighbor.resize(grids[l].seq2grid.size() * 9, -1); + grids[l].num_origin_seq = grids[l].seq2grid.size(); + for (int d = 0; d < 3; ++d) { + NeighborGrid(grids[l], view_layer_positions, d); + } + } + + for (int i = num_level - 2; i >= 0; --i) { + PadGrid(grids[i], grids[i + 1], view_layer_positions); + } + for (int i = grids[0].num_origin_seq; i < grids[0].seq2grid.size(); ++i) { + int k = grids[0].seq2grid[i]; + float p[3]; + key2pos(k, grids[0].resolution, p); + seq2pos.push_back(p[0]); + seq2pos.push_back(p[1]); + seq2pos.push_back(p[2]); + for (int c = 0; c < feat_channel; ++c) { + seq2feat.emplace_back(0.5); + } + } + + std::vector texture_positions(2); + std::vector texture_feats(1); + std::vector grid_neighbors(grids.size()); + std::vector grid_downsamples(grids.size() - 1); + std::vector grid_evencorners(grids.size()); + std::vector grid_oddcorners(grids.size()); + + texture_positions[0] = torch::zeros({static_cast(seq2pos.size() / 3), static_cast(3)}, float_options); + texture_positions[1] = torch::zeros({static_cast(seq2pos.size() / 3)}, float_options); + texture_feats[0] = torch::zeros({static_cast(seq2feat.size() / feat_channel), static_cast(feat_channel)}, float_options); + float* positions_out_ptr = texture_positions[0].data_ptr(); + memcpy(positions_out_ptr, seq2pos.data(), sizeof(float) * seq2pos.size()); + positions_out_ptr = texture_positions[1].data_ptr(); + for (int i = 0; i < grids[0].seq2grid.size(); ++i) { + positions_out_ptr[i] = (i < grids[0].num_origin_seq); + } + float* feats_out_ptr = texture_feats[0].data_ptr(); + memcpy(feats_out_ptr, seq2feat.data(), sizeof(float) * seq2feat.size()); + + for (int i = 0; i < grids.size(); ++i) { + grid_neighbors[i] = torch::zeros({static_cast(grids[i].seq2grid.size()), static_cast(9)}, int64_options); + int64_t* nptr = grid_neighbors[i].data_ptr(); + for (int j = 0; j < grids[i].seq2neighbor.size(); ++j) { + nptr[j] = grids[i].seq2neighbor[j]; + } + grid_evencorners[i] = torch::zeros({static_cast(grids[i].seq2evencorner.size())}, int64_options); + grid_oddcorners[i] = torch::zeros({static_cast(grids[i].seq2oddcorner.size())}, int64_options); + int64_t* dptr = grid_evencorners[i].data_ptr(); + for (int j = 0; j < grids[i].seq2evencorner.size(); ++j) { + dptr[j] = grids[i].seq2evencorner[j]; + } + dptr = grid_oddcorners[i].data_ptr(); + for (int j = 0; j < grids[i].seq2oddcorner.size(); ++j) { + dptr[j] = grids[i].seq2oddcorner[j]; + } + if (i + 1 < grids.size()) { + grid_downsamples[i] = torch::zeros({static_cast(grids[i].downsample_seq.size())}, int64_options); + int64_t* dptr = grid_downsamples[i].data_ptr(); + for (int j = 0; j < grids[i].downsample_seq.size(); ++j) { + dptr[j] = grids[i].downsample_seq[j]; + } + } + } + return {texture_positions, texture_feats, grid_neighbors, grid_downsamples, grid_evencorners, grid_oddcorners}; +} diff --git a/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/rasterizer.cpp b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/rasterizer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4af6eebfeb44f1578ae99f07601f2e2fa1c3c0d8 --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/rasterizer.cpp @@ -0,0 +1,139 @@ +#include "rasterizer.h" + +void rasterizeTriangleCPU(int idx, float* vt0, float* vt1, float* vt2, int width, int height, INT64* zbuffer, float* d, float occlusion_truncation) { + float x_min = std::min(vt0[0], std::min(vt1[0],vt2[0])); + float x_max = std::max(vt0[0], std::max(vt1[0],vt2[0])); + float y_min = std::min(vt0[1], std::min(vt1[1],vt2[1])); + float y_max = std::max(vt0[1], std::max(vt1[1],vt2[1])); + + for (int px = x_min; px < x_max + 1; ++px) { + if (px < 0 || px >= width) + continue; + for (int py = y_min; py < y_max + 1; ++py) { + if (py < 0 || py >= height) + continue; + float vt[2] = {px + 0.5, py + 0.5}; + float baryCentricCoordinate[3]; + calculateBarycentricCoordinate(vt0, vt1, vt2, vt, baryCentricCoordinate); + if (isBarycentricCoordInBounds(baryCentricCoordinate)) { + int pixel = py * width + px; + if (zbuffer == 0) { + zbuffer[pixel] = (INT64)(idx + 1); + continue; + } + + float depth = baryCentricCoordinate[0] * vt0[2] + baryCentricCoordinate[1] * vt1[2] + baryCentricCoordinate[2] * vt2[2]; + float depth_thres = 0; + if (d) { + depth_thres = d[pixel] * 0.49999f + 0.5f + occlusion_truncation; + } + + int z_quantize = depth * (2<<17); + INT64 token = (INT64)z_quantize * MAXINT + (INT64)(idx + 1); + if (depth < depth_thres) + continue; + zbuffer[pixel] = std::min(zbuffer[pixel], token); + } + } + } +} + +void barycentricFromImgcoordCPU(float* V, int* F, int* findices, INT64* zbuffer, int width, int height, int num_vertices, int num_faces, + float* barycentric_map, int pix) +{ + INT64 f = zbuffer[pix] % MAXINT; + if (f == (MAXINT-1)) { + findices[pix] = 0; + barycentric_map[pix * 3] = 0; + barycentric_map[pix * 3 + 1] = 0; + barycentric_map[pix * 3 + 2] = 0; + return; + } + findices[pix] = f; + f -= 1; + float barycentric[3] = {0, 0, 0}; + if (f >= 0) { + float vt[2] = {float(pix % width) + 0.5f, float(pix / width) + 0.5f}; + float* vt0_ptr = V + (F[f * 3] * 4); + float* vt1_ptr = V + (F[f * 3 + 1] * 4); + float* vt2_ptr = V + (F[f * 3 + 2] * 4); + + float vt0[2] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f}; + float vt1[2] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f}; + float vt2[2] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f}; + + calculateBarycentricCoordinate(vt0, vt1, vt2, vt, barycentric); + + barycentric[0] = barycentric[0] / vt0_ptr[3]; + barycentric[1] = barycentric[1] / vt1_ptr[3]; + barycentric[2] = barycentric[2] / vt2_ptr[3]; + float w = 1.0f / (barycentric[0] + barycentric[1] + barycentric[2]); + barycentric[0] *= w; + barycentric[1] *= w; + barycentric[2] *= w; + + } + barycentric_map[pix * 3] = barycentric[0]; + barycentric_map[pix * 3 + 1] = barycentric[1]; + barycentric_map[pix * 3 + 2] = barycentric[2]; +} + +void rasterizeImagecoordsKernelCPU(float* V, int* F, float* d, INT64* zbuffer, float occlusion_trunc, int width, int height, int num_vertices, int num_faces, int f) +{ + float* vt0_ptr = V + (F[f * 3] * 4); + float* vt1_ptr = V + (F[f * 3 + 1] * 4); + float* vt2_ptr = V + (F[f * 3 + 2] * 4); + + float vt0[3] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f, vt0_ptr[2] / vt0_ptr[3] * 0.49999f + 0.5f}; + float vt1[3] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f, vt1_ptr[2] / vt1_ptr[3] * 0.49999f + 0.5f}; + float vt2[3] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f, vt2_ptr[2] / vt2_ptr[3] * 0.49999f + 0.5f}; + + rasterizeTriangleCPU(f, vt0, vt1, vt2, width, height, zbuffer, d, occlusion_trunc); +} + +std::vector rasterize_image_cpu(torch::Tensor V, torch::Tensor F, torch::Tensor D, + int width, int height, float occlusion_truncation, int use_depth_prior) +{ + int num_faces = F.size(0); + int num_vertices = V.size(0); + auto options = torch::TensorOptions().dtype(torch::kInt32).requires_grad(false); + auto INT64_options = torch::TensorOptions().dtype(torch::kInt64).requires_grad(false); + auto findices = torch::zeros({height, width}, options); + INT64 maxint = (INT64)MAXINT * (INT64)MAXINT + (MAXINT - 1); + auto z_min = torch::ones({height, width}, INT64_options) * (int64_t)maxint; + + if (!use_depth_prior) { + for (int i = 0; i < num_faces; ++i) { + rasterizeImagecoordsKernelCPU(V.data_ptr(), F.data_ptr(), 0, + (INT64*)z_min.data_ptr(), occlusion_truncation, width, height, num_vertices, num_faces, i); + } + } else { + for (int i = 0; i < num_faces; ++i) + rasterizeImagecoordsKernelCPU(V.data_ptr(), F.data_ptr(), D.data_ptr(), + (INT64*)z_min.data_ptr(), occlusion_truncation, width, height, num_vertices, num_faces, i); + } + + auto float_options = torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + auto barycentric = torch::zeros({height, width, 3}, float_options); + for (int i = 0; i < width * height; ++i) + barycentricFromImgcoordCPU(V.data_ptr(), F.data_ptr(), + findices.data_ptr(), (INT64*)z_min.data_ptr(), width, height, num_vertices, num_faces, barycentric.data_ptr(), i); + + return {findices, barycentric}; +} + +std::vector rasterize_image(torch::Tensor V, torch::Tensor F, torch::Tensor D, + int width, int height, float occlusion_truncation, int use_depth_prior) +{ + int device_id = V.get_device(); + if (device_id == -1) + return rasterize_image_cpu(V, F, D, width, height, occlusion_truncation, use_depth_prior); + else + return rasterize_image_gpu(V, F, D, width, height, occlusion_truncation, use_depth_prior); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rasterize_image", &rasterize_image, "Custom image rasterization"); + m.def("build_hierarchy", &build_hierarchy, "Custom image rasterization"); + m.def("build_hierarchy_with_feat", &build_hierarchy_with_feat, "Custom image rasterization"); +} diff --git a/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/rasterizer.h b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/rasterizer.h new file mode 100644 index 0000000000000000000000000000000000000000..cf4f9870bda0714763e4236f85293ca7cef7d51f --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/rasterizer.h @@ -0,0 +1,54 @@ +#ifndef RASTERIZER_H_ +#define RASTERIZER_H_ + +#include +#include +#include +#include // For CUDA context + +#define INT64 unsigned long long +#define MAXINT 2147483647 + +__host__ __device__ inline float calculateSignedArea2(float* a, float* b, float* c) { + return ((c[0] - a[0]) * (b[1] - a[1]) - (b[0] - a[0]) * (c[1] - a[1])); +} + +__host__ __device__ inline void calculateBarycentricCoordinate(float* a, float* b, float* c, float* p, + float* barycentric) +{ + float beta_tri = calculateSignedArea2(a, p, c); + float gamma_tri = calculateSignedArea2(a, b, p); + float area = calculateSignedArea2(a, b, c); + if (area == 0) { + barycentric[0] = -1.0; + barycentric[1] = -1.0; + barycentric[2] = -1.0; + return; + } + float tri_inv = 1.0 / area; + float beta = beta_tri * tri_inv; + float gamma = gamma_tri * tri_inv; + float alpha = 1.0 - beta - gamma; + barycentric[0] = alpha; + barycentric[1] = beta; + barycentric[2] = gamma; +} + +__host__ __device__ inline bool isBarycentricCoordInBounds(float* barycentricCoord) { + return barycentricCoord[0] >= 0.0 && barycentricCoord[0] <= 1.0 && + barycentricCoord[1] >= 0.0 && barycentricCoord[1] <= 1.0 && + barycentricCoord[2] >= 0.0 && barycentricCoord[2] <= 1.0; +} + +std::vector rasterize_image_gpu(torch::Tensor V, torch::Tensor F, torch::Tensor D, + int width, int height, float occlusion_truncation, int use_depth_prior); + +std::vector> build_hierarchy(std::vector view_layer_positions, std::vector view_layer_normals, int num_level, int resolution); + +std::vector> build_hierarchy_with_feat( + std::vector view_layer_positions, + std::vector view_layer_normals, + std::vector view_layer_feats, + int num_level, int resolution); + +#endif \ No newline at end of file diff --git a/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/rasterizer_gpu.cu b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/rasterizer_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..cc6f354c0e2801b9ac84ec4547845c8edb606a60 --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/lib/custom_rasterizer_kernel/rasterizer_gpu.cu @@ -0,0 +1,127 @@ +#include "rasterizer.h" + +__device__ void rasterizeTriangleGPU(int idx, float* vt0, float* vt1, float* vt2, int width, int height, INT64* zbuffer, float* d, float occlusion_truncation) { + float x_min = std::min(vt0[0], std::min(vt1[0],vt2[0])); + float x_max = std::max(vt0[0], std::max(vt1[0],vt2[0])); + float y_min = std::min(vt0[1], std::min(vt1[1],vt2[1])); + float y_max = std::max(vt0[1], std::max(vt1[1],vt2[1])); + + for (int px = x_min; px < x_max + 1; ++px) { + if (px < 0 || px >= width) + continue; + for (int py = y_min; py < y_max + 1; ++py) { + if (py < 0 || py >= height) + continue; + float vt[2] = {px + 0.5f, py + 0.5f}; + float baryCentricCoordinate[3]; + calculateBarycentricCoordinate(vt0, vt1, vt2, vt, baryCentricCoordinate); + if (isBarycentricCoordInBounds(baryCentricCoordinate)) { + int pixel = py * width + px; + if (zbuffer == 0) { + atomicExch(&zbuffer[pixel], (INT64)(idx + 1)); + continue; + } + float depth = baryCentricCoordinate[0] * vt0[2] + baryCentricCoordinate[1] * vt1[2] + baryCentricCoordinate[2] * vt2[2]; + float depth_thres = 0; + if (d) { + depth_thres = d[pixel] * 0.49999f + 0.5f + occlusion_truncation; + } + + int z_quantize = depth * (2<<17); + INT64 token = (INT64)z_quantize * MAXINT + (INT64)(idx + 1); + if (depth < depth_thres) + continue; + atomicMin(&zbuffer[pixel], token); + } + } + } +} + +__global__ void barycentricFromImgcoordGPU(float* V, int* F, int* findices, INT64* zbuffer, int width, int height, int num_vertices, int num_faces, + float* barycentric_map) +{ + int pix = blockIdx.x * blockDim.x + threadIdx.x; + if (pix >= width * height) + return; + INT64 f = zbuffer[pix] % MAXINT; + if (f == (MAXINT-1)) { + findices[pix] = 0; + barycentric_map[pix * 3] = 0; + barycentric_map[pix * 3 + 1] = 0; + barycentric_map[pix * 3 + 2] = 0; + return; + } + findices[pix] = f; + f -= 1; + float barycentric[3] = {0, 0, 0}; + if (f >= 0) { + float vt[2] = {float(pix % width) + 0.5f, float(pix / width) + 0.5f}; + float* vt0_ptr = V + (F[f * 3] * 4); + float* vt1_ptr = V + (F[f * 3 + 1] * 4); + float* vt2_ptr = V + (F[f * 3 + 2] * 4); + + float vt0[2] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f}; + float vt1[2] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f}; + float vt2[2] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f}; + + calculateBarycentricCoordinate(vt0, vt1, vt2, vt, barycentric); + + barycentric[0] = barycentric[0] / vt0_ptr[3]; + barycentric[1] = barycentric[1] / vt1_ptr[3]; + barycentric[2] = barycentric[2] / vt2_ptr[3]; + float w = 1.0f / (barycentric[0] + barycentric[1] + barycentric[2]); + barycentric[0] *= w; + barycentric[1] *= w; + barycentric[2] *= w; + + } + barycentric_map[pix * 3] = barycentric[0]; + barycentric_map[pix * 3 + 1] = barycentric[1]; + barycentric_map[pix * 3 + 2] = barycentric[2]; +} + +__global__ void rasterizeImagecoordsKernelGPU(float* V, int* F, float* d, INT64* zbuffer, float occlusion_trunc, int width, int height, int num_vertices, int num_faces) +{ + int f = blockIdx.x * blockDim.x + threadIdx.x; + if (f >= num_faces) + return; + + float* vt0_ptr = V + (F[f * 3] * 4); + float* vt1_ptr = V + (F[f * 3 + 1] * 4); + float* vt2_ptr = V + (F[f * 3 + 2] * 4); + + float vt0[3] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f, vt0_ptr[2] / vt0_ptr[3] * 0.49999f + 0.5f}; + float vt1[3] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f, vt1_ptr[2] / vt1_ptr[3] * 0.49999f + 0.5f}; + float vt2[3] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f, vt2_ptr[2] / vt2_ptr[3] * 0.49999f + 0.5f}; + + rasterizeTriangleGPU(f, vt0, vt1, vt2, width, height, zbuffer, d, occlusion_trunc); +} + +std::vector rasterize_image_gpu(torch::Tensor V, torch::Tensor F, torch::Tensor D, + int width, int height, float occlusion_truncation, int use_depth_prior) +{ + int device_id = V.get_device(); + cudaSetDevice(device_id); + int num_faces = F.size(0); + int num_vertices = V.size(0); + auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, device_id).requires_grad(false); + auto INT64_options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA, device_id).requires_grad(false); + auto findices = torch::zeros({height, width}, options); + INT64 maxint = (INT64)MAXINT * (INT64)MAXINT + (MAXINT - 1); + auto z_min = torch::ones({height, width}, INT64_options) * (int64_t)maxint; + + if (!use_depth_prior) { + rasterizeImagecoordsKernelGPU<<<(num_faces+255)/256,256,0,at::cuda::getCurrentCUDAStream()>>>(V.data_ptr(), F.data_ptr(), 0, + (INT64*)z_min.data_ptr(), occlusion_truncation, width, height, num_vertices, num_faces); + } else { + rasterizeImagecoordsKernelGPU<<<(num_faces+255)/256,256,0,at::cuda::getCurrentCUDAStream()>>>(V.data_ptr(), F.data_ptr(), D.data_ptr(), + (INT64*)z_min.data_ptr(), occlusion_truncation, width, height, num_vertices, num_faces); + } + + auto float_options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, device_id).requires_grad(false); + auto barycentric = torch::zeros({height, width, 3}, float_options); + barycentricFromImgcoordGPU<<<(width * height + 255)/256, 256>>>(V.data_ptr(), F.data_ptr(), + findices.data_ptr(), (INT64*)z_min.data_ptr(), width, height, num_vertices, num_faces, barycentric.data_ptr()); + + return {findices, barycentric}; +} diff --git a/step1x3d_texture/custom_rasterizer/setup.py b/step1x3d_texture/custom_rasterizer/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..92a00d65979938b711e653a7b4e0cdef07a9ff0e --- /dev/null +++ b/step1x3d_texture/custom_rasterizer/setup.py @@ -0,0 +1,27 @@ +from setuptools import setup, find_packages +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +# build custom rasterizer +# build with `python setup.py install` +# nvcc is needed + +custom_rasterizer_module = CUDAExtension( + "custom_rasterizer_kernel", + [ + "lib/custom_rasterizer_kernel/rasterizer.cpp", + "lib/custom_rasterizer_kernel/grid_neighbor.cpp", + "lib/custom_rasterizer_kernel/rasterizer_gpu.cu", + ], +) + +setup( + packages=find_packages(), + version="0.1", + name="custom_rasterizer", + include_package_data=True, + package_dir={"": "."}, + ext_modules=[ + custom_rasterizer_module, + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/step1x3d_texture/data/multiview.py b/step1x3d_texture/data/multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..aff8e5bf2bf613e830e4775e92496fa965d1c3fa --- /dev/null +++ b/step1x3d_texture/data/multiview.py @@ -0,0 +1,540 @@ +import json +import os +import random +from dataclasses import dataclass, field + +import cv2 +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from PIL import Image +from torch.utils.data import DataLoader, Dataset + +from ..utils.config import parse_structured +from ..utils.geometry import ( + get_plucker_embeds_from_cameras, + get_plucker_embeds_from_cameras_ortho, + get_position_map_from_depth, + get_position_map_from_depth_ortho, +) +from ..utils.typing import * + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" + + +def _parse_scene_list_single(scene_list_path: str, root_data_dir: str): + all_scenes = [] + if scene_list_path.endswith(".json"): + with open(scene_list_path) as f: + for p in json.loads(f.read()): + if "/" in p: + all_scenes.append(os.path.join(root_data_dir, p)) + else: + all_scenes.append(os.path.join(root_data_dir, p[:2], p)) + elif scene_list_path.endswith(".txt"): + with open(scene_list_path) as f: + for p in f.readlines(): + p = p.strip() + if "/" in p: + all_scenes.append(os.path.join(root_data_dir, p)) + else: + all_scenes.append(os.path.join(root_data_dir, p[:2], p)) + else: + raise NotImplementedError + + return all_scenes + + +def _parse_scene_list( + scene_list_path: Union[str, List[str]], root_data_dir: Union[str, List[str]] +): + all_scenes = [] + if isinstance(scene_list_path, str): + scene_list_path = [scene_list_path] + if isinstance(root_data_dir, str): + root_data_dir = [root_data_dir] + for scene_list_path_, root_data_dir_ in zip(scene_list_path, root_data_dir): + all_scenes += _parse_scene_list_single(scene_list_path_, root_data_dir_) + return all_scenes + + +def _parse_reference_scene_list(reference_scenes: List[str], all_scenes: List[str]): + all_ids = set(scene.split("/")[-1] for scene in all_scenes) + ref_ids = set(scene.split("/")[-1] for scene in reference_scenes) + common_ids = ref_ids.intersection(all_ids) + all_scenes = [scene for scene in all_scenes if scene.split("/")[-1] in common_ids] + all_ids = {scene.split("/")[-1]: idx for idx, scene in enumerate(all_scenes)} + + ref_scenes = [ + scene for scene in reference_scenes if scene.split("/")[-1] in all_ids + ] + sorted_ref_scenes = sorted(ref_scenes, key=lambda x: all_ids[x.split("/")[-1]]) + scene2ref = { + scene: ref_scene for scene, ref_scene in zip(all_scenes, sorted_ref_scenes) + } + + return all_scenes, scene2ref + + +@dataclass +class MultiviewDataModuleConfig: + root_dir: Any = "" + scene_list: Any = "" + image_suffix: str = "webp" + background_color: Union[str, float] = "gray" + image_names: List[str] = field(default_factory=lambda: []) + image_modality: str = "render" + num_views: int = 1 + random_view_list: Optional[List[List[int]]] = None + + prompt_db_path: Optional[str] = None + return_prompt: bool = False + use_empty_prompt: bool = False + prompt_prefix: Optional[Any] = None + return_one_prompt: bool = True + + projection_type: str = "ORTHO" + + # source conditions + source_image_modality: Any = "position" + use_camera_space_normal: bool = False + position_offset: float = 0.5 + position_scale: float = 1.0 + plucker_offset: float = 1.0 + plucker_scale: float = 2.0 + + # reference image + reference_root_dir: Optional[Any] = None + reference_scene_list: Optional[Any] = None + reference_image_modality: str = "render" + reference_image_names: List[str] = field(default_factory=lambda: []) + reference_augment_resolutions: Optional[List[int]] = None + reference_mask_aug: bool = False + + repeat: int = 1 # for debugging purpose + + train_indices: Optional[Tuple[Any, Any]] = None + val_indices: Optional[Tuple[Any, Any]] = None + test_indices: Optional[Tuple[Any, Any]] = None + + height: int = 768 + width: int = 768 + + batch_size: int = 1 + eval_batch_size: int = 1 + + num_workers: int = 16 + + +class MultiviewDataset(Dataset): + def __init__(self, cfg: Any, split: str = "train") -> None: + super().__init__() + assert split in ["train", "val", "test"] + self.cfg: MultiviewDataModuleConfig = cfg + self.all_scenes = _parse_scene_list(self.cfg.scene_list, self.cfg.root_dir) + + if ( + self.cfg.reference_root_dir is not None + and self.cfg.reference_scene_list is not None + ): + reference_scenes = _parse_scene_list( + self.cfg.reference_scene_list, self.cfg.reference_root_dir + ) + self.all_scenes, self.reference_scenes = _parse_reference_scene_list( + reference_scenes, self.all_scenes + ) + else: + self.reference_scenes = None + + self.split = split + if self.split == "train" and self.cfg.train_indices is not None: + self.all_scenes = self.all_scenes[ + self.cfg.train_indices[0] : self.cfg.train_indices[1] + ] + self.all_scenes = self.all_scenes * self.cfg.repeat + elif self.split == "val" and self.cfg.val_indices is not None: + self.all_scenes = self.all_scenes[ + self.cfg.val_indices[0] : self.cfg.val_indices[1] + ] + elif self.split == "test" and self.cfg.test_indices is not None: + self.all_scenes = self.all_scenes[ + self.cfg.test_indices[0] : self.cfg.test_indices[1] + ] + + if self.cfg.prompt_db_path is not None: + self.prompt_db = json.load(open(self.cfg.prompt_db_path)) + else: + self.prompt_db = None + + def __len__(self): + return len(self.all_scenes) + + def get_bg_color(self, bg_color): + if bg_color == "white": + bg_color = np.array([1.0, 1.0, 1.0], dtype=np.float32) + elif bg_color == "black": + bg_color = np.array([0.0, 0.0, 0.0], dtype=np.float32) + elif bg_color == "gray": + bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) + elif bg_color == "random": + bg_color = np.random.rand(3) + elif bg_color == "random_gray": + bg_color = random.uniform(0.3, 0.7) + bg_color = np.array([bg_color] * 3, dtype=np.float32) + elif isinstance(bg_color, float): + bg_color = np.array([bg_color] * 3, dtype=np.float32) + elif isinstance(bg_color, list) or isinstance(bg_color, tuple): + bg_color = np.array(bg_color, dtype=np.float32) + else: + raise NotImplementedError + return bg_color + + def load_image( + self, + image: Union[str, Image.Image], + height: int, + width: int, + background_color: torch.Tensor, + rescale: bool = False, + mask_aug: bool = False, + ): + if isinstance(image, str): + image = Image.open(image) + + image = image.resize((width, height)) + image = torch.from_numpy(np.array(image)).float() / 255.0 + + if mask_aug: + alpha = image[:, :, 3] # Extract alpha channel + h, w = alpha.shape + y_indices, x_indices = torch.where(alpha > 0.5) + if len(y_indices) > 0 and len(x_indices) > 0: + idx = torch.randint(len(y_indices), (1,)).item() + y_center = y_indices[idx].item() + x_center = x_indices[idx].item() + mask_h = random.randint(h // 8, h // 4) + mask_w = random.randint(w // 8, w // 4) + + y1 = max(0, y_center - mask_h // 2) + y2 = min(h, y_center + mask_h // 2) + x1 = max(0, x_center - mask_w // 2) + x2 = min(w, x_center + mask_w // 2) + + alpha[y1:y2, x1:x2] = 0.0 + image[:, :, 3] = alpha + + image = image[:, :, :3] * image[:, :, 3:4] + background_color * ( + 1 - image[:, :, 3:4] + ) + if rescale: + image = image * 2.0 - 1.0 + return image + + def load_normal_image( + self, + path, + height, + width, + background_color, + camera_space: bool = False, + c2w: Optional[torch.FloatTensor] = None, + ): + image = Image.open(path).resize((width, height), resample=Image.NEAREST) + image = torch.from_numpy(np.array(image)).float() / 255.0 + alpha = image[:, :, 3:4] + image = image[:, :, :3] + if camera_space: + w2c = torch.linalg.inv(c2w)[:3, :3] + image = ( + F.normalize(((image * 2 - 1)[:, :, None, :] * w2c).sum(-1), dim=-1) + * 0.5 + + 0.5 + ) + image = image * alpha + background_color * (1 - alpha) + return image + + def load_depth(self, path, height, width): + depth = cv2.imread(path, cv2.IMREAD_UNCHANGED) + depth = cv2.resize(depth, (width, height), interpolation=cv2.INTER_NEAREST) + depth = torch.from_numpy(depth[..., 0:1]).float() + mask = torch.ones_like(depth) + mask[depth > 1000.0] = 0.0 # depth = 65535 is the invalid value + depth[~(mask > 0.5)] = 0.0 + return depth, mask + + def retrieve_prompt(self, scene_dir): + assert self.prompt_db is not None + source_id = os.path.basename(scene_dir) + return self.prompt_db.get(source_id, "") + + def __getitem__(self, index): + background_color = torch.as_tensor(self.get_bg_color(self.cfg.background_color)) + scene_dir = self.all_scenes[index] + + with open(os.path.join(scene_dir, "meta.json")) as f: + meta = json.load(f) + name2loc = {loc["index"]: loc for loc in meta["locations"]} + + # target multi-view images + image_paths = [ + os.path.join( + scene_dir, f"{self.cfg.image_modality}_{f}.{self.cfg.image_suffix}" + ) + for f in self.cfg.image_names + ] + images = [ + self.load_image( + p, + height=self.cfg.height, + width=self.cfg.width, + background_color=background_color, + ) + for p in image_paths + ] + images = torch.stack(images, dim=0).permute(0, 3, 1, 2) + + # camera + c2w = [ + torch.as_tensor(name2loc[name]["transform_matrix"]) + for name in self.cfg.image_names + ] + c2w = torch.stack(c2w, dim=0) + + if self.cfg.projection_type == "PERSP": + camera_angle_x = ( + meta.get("camera_angle_x", None) + or meta["locations"][0]["camera_angle_x"] + ) + focal_length = 0.5 * self.cfg.width / np.tan(0.5 * camera_angle_x) + intrinsics = ( + torch.as_tensor( + [ + [focal_length, 0.0, 0.5 * self.cfg.width], + [0.0, focal_length, 0.5 * self.cfg.height], + [0.0, 0.0, 1.0], + ] + ) + .unsqueeze(0) + .float() + .repeat(len(self.cfg.image_names), 1, 1) + ) + elif self.cfg.projection_type == "ORTHO": + ortho_scale = ( + meta.get("ortho_scale", None) or meta["locations"][0]["ortho_scale"] + ) + + # source conditions + source_image_modality = self.cfg.source_image_modality + if isinstance(source_image_modality, str): + source_image_modality = [source_image_modality] + source_images = [] + for modality in source_image_modality: + if modality == "position": + depth_masks = [ + self.load_depth( + os.path.join(scene_dir, f"depth_{f}.exr"), + self.cfg.height, + self.cfg.width, + ) + for f in self.cfg.image_names + ] + depths = torch.stack([d for d, _ in depth_masks]) + masks = torch.stack([m for _, m in depth_masks]) + c2w_ = c2w.clone() + c2w_[:, :, 1:3] *= -1 + + if self.cfg.projection_type == "PERSP": + position_maps = get_position_map_from_depth( + depths, + masks, + intrinsics, + c2w_, + image_wh=(self.cfg.width, self.cfg.height), + ) + elif self.cfg.projection_type == "ORTHO": + position_maps = get_position_map_from_depth_ortho( + depths, + masks, + c2w_, + ortho_scale, + image_wh=(self.cfg.width, self.cfg.height), + ) + position_maps = ( + (position_maps + self.cfg.position_offset) / self.cfg.position_scale + ).clamp(0.0, 1.0) + source_images.append(position_maps) + elif modality == "normal": + normal_maps = [ + self.load_normal_image( + os.path.join( + scene_dir, f"{modality}_{f}.{self.cfg.image_suffix}" + ), + height=self.cfg.height, + width=self.cfg.width, + background_color=background_color, + camera_space=self.cfg.use_camera_space_normal, + c2w=c, + ) + for c, f in zip(c2w, self.cfg.image_names) + ] + source_images.append(torch.stack(normal_maps, dim=0)) + elif modality == "plucker": + if self.cfg.projection_type == "ORTHO": + plucker_embed = get_plucker_embeds_from_cameras_ortho( + c2w, [ortho_scale] * len(c2w), self.cfg.width + ) + elif self.cfg.projection_type == "PERSP": + plucker_embed = get_plucker_embeds_from_cameras( + c2w, [camera_angle_x] * len(c2w), self.cfg.width + ) + else: + raise NotImplementedError + plucker_embed = plucker_embed.permute(0, 2, 3, 1) + plucker_embed = ( + (plucker_embed + self.cfg.plucker_offset) / self.cfg.plucker_scale + ).clamp(0.0, 1.0) + source_images.append(plucker_embed) + else: + raise NotImplementedError + source_images = torch.cat(source_images, dim=-1).permute(0, 3, 1, 2) + rv = {"rgb": images, "c2w": c2w, "source_rgb": source_images} + + num_images = len(self.cfg.image_names) + # prompt + if self.cfg.return_prompt: + if self.cfg.use_empty_prompt: + prompt = "" + else: + prompt = self.retrieve_prompt(scene_dir) + prompts = [prompt] * num_images + + if self.cfg.prompt_prefix is not None: + prompt_prefix = self.cfg.prompt_prefix + if isinstance(prompt_prefix, str): + prompt_prefix = [prompt_prefix] * num_images + + for i, prompt in enumerate(prompts): + prompts[i] = f"{prompt_prefix[i]} {prompt}" + + if self.cfg.return_one_prompt: + rv.update({"prompts": prompts[0]}) + else: + rv.update({"prompts": prompts}) + + # reference image + if self.reference_scenes is not None: + reference_scene_dir = self.reference_scenes[scene_dir] + reference_image_paths = [ + os.path.join( + reference_scene_dir, + f"{self.cfg.reference_image_modality}_{f}.{self.cfg.image_suffix}", + ) + for f in self.cfg.reference_image_names + ] + reference_image_path = random.choice(reference_image_paths) + + if self.cfg.reference_augment_resolutions is None: + reference_image = self.load_image( + reference_image_path, + height=self.cfg.height, + width=self.cfg.width, + background_color=background_color, + mask_aug=self.cfg.reference_mask_aug, + ).permute(2, 0, 1) + rv.update({"reference_rgb": reference_image}) + else: + random_resolution = random.choice( + self.cfg.reference_augment_resolutions + ) + reference_image_ = Image.open(reference_image_path).resize( + (random_resolution, random_resolution) + ) + reference_image = self.load_image( + reference_image_, + height=self.cfg.height, + width=self.cfg.width, + background_color=background_color, + mask_aug=self.cfg.reference_mask_aug, + ).permute(2, 0, 1) + rv.update({"reference_rgb": reference_image}) + + return rv + + def collate(self, batch): + batch = torch.utils.data.default_collate(batch) + pack = lambda t: t.view(-1, *t.shape[2:]) + + if self.cfg.random_view_list is not None: + indices = random.choice(self.cfg.random_view_list) + else: + indices = list(range(self.cfg.num_views)) + num_views = len(indices) + + for k in batch.keys(): + if k in ["rgb", "source_rgb", "c2w"]: + batch[k] = batch[k][:, indices] + batch[k] = pack(batch[k]) + for k in ["prompts"]: + if not self.cfg.return_one_prompt: + batch[k] = [item for pair in zip(*batch[k]) for item in pair] + + batch.update( + { + "num_views": num_views, + # For SDXL + "original_size": (self.cfg.height, self.cfg.width), + "target_size": (self.cfg.height, self.cfg.width), + "crops_coords_top_left": (0, 0), + } + ) + return batch + + +class MultiviewDataModule(pl.LightningDataModule): + cfg: MultiviewDataModuleConfig + + def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: + super().__init__() + self.cfg = parse_structured(MultiviewDataModuleConfig, cfg) + + def setup(self, stage=None) -> None: + if stage in [None, "fit"]: + self.train_dataset = MultiviewDataset(self.cfg, "train") + if stage in [None, "fit", "validate"]: + self.val_dataset = MultiviewDataset(self.cfg, "val") + if stage in [None, "test", "predict"]: + self.test_dataset = MultiviewDataset(self.cfg, "test") + + def prepare_data(self): + pass + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, + batch_size=self.cfg.batch_size, + num_workers=self.cfg.num_workers, + shuffle=True, + collate_fn=self.train_dataset.collate, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + batch_size=self.cfg.eval_batch_size, + num_workers=self.cfg.num_workers, + shuffle=False, + collate_fn=self.val_dataset.collate, + ) + + def test_dataloader(self) -> DataLoader: + return DataLoader( + self.test_dataset, + batch_size=self.cfg.eval_batch_size, + num_workers=self.cfg.num_workers, + shuffle=False, + collate_fn=self.test_dataset.collate, + ) + + def predict_dataloader(self) -> DataLoader: + return self.test_dataloader() diff --git a/step1x3d_texture/differentiable_renderer/__init__.py b/step1x3d_texture/differentiable_renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1614ff826f65b4720649d343c43fb09dbb6b9fa5 --- /dev/null +++ b/step1x3d_texture/differentiable_renderer/__init__.py @@ -0,0 +1,13 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. diff --git a/step1x3d_texture/differentiable_renderer/camera_utils.py b/step1x3d_texture/differentiable_renderer/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..260611f9566e56f0e090fa67445e864b63468c80 --- /dev/null +++ b/step1x3d_texture/differentiable_renderer/camera_utils.py @@ -0,0 +1,118 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import math + +import numpy as np +import torch + + +def transform_pos(mtx, pos, keepdim=False): + t_mtx = torch.from_numpy(mtx).to(pos.device) if isinstance(mtx, np.ndarray) else mtx + if pos.shape[-1] == 3: + posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1) + else: + posw = pos + + if keepdim: + return torch.matmul(posw, t_mtx.t())[...] + else: + return torch.matmul(posw, t_mtx.t())[None, ...] + + +def get_mv_matrix(elev, azim, camera_distance, center=None): + elev = -elev + azim += 90 + + elev_rad = math.radians(elev) + azim_rad = math.radians(azim) + + camera_position = np.array( + [ + camera_distance * math.cos(elev_rad) * math.cos(azim_rad), + camera_distance * math.cos(elev_rad) * math.sin(azim_rad), + camera_distance * math.sin(elev_rad), + ] + ) + + if center is None: + center = np.array([0, 0, 0]) + else: + center = np.array(center) + + lookat = center - camera_position + lookat = lookat / np.linalg.norm(lookat) + + up = np.array([0, 0, 1.0]) + right = np.cross(lookat, up) + right = right / np.linalg.norm(right) + up = np.cross(right, lookat) + up = up / np.linalg.norm(up) + + c2w = np.concatenate( + [np.stack([right, up, -lookat], axis=-1), camera_position[:, None]], axis=-1 + ) + + w2c = np.zeros((4, 4)) + w2c[:3, :3] = np.transpose(c2w[:3, :3], (1, 0)) + w2c[:3, 3:] = -np.matmul(np.transpose(c2w[:3, :3], (1, 0)), c2w[:3, 3:]) + w2c[3, 3] = 1.0 + + return w2c.astype(np.float32) + + +def get_orthographic_projection_matrix( + left=-1, right=1, bottom=-1, top=1, near=0, far=2 +): + """ + 计算正交投影矩阵。 + + 参数: + left (float): 投影区域左侧边界。 + right (float): 投影区域右侧边界。 + bottom (float): 投影区域底部边界。 + top (float): 投影区域顶部边界。 + near (float): 投影区域近裁剪面距离。 + far (float): 投影区域远裁剪面距离。 + + 返回: + numpy.ndarray: 正交投影矩阵。 + """ + ortho_matrix = np.eye(4, dtype=np.float32) + ortho_matrix[0, 0] = 2 / (right - left) + ortho_matrix[1, 1] = 2 / (top - bottom) + ortho_matrix[2, 2] = -2 / (far - near) + ortho_matrix[0, 3] = -(right + left) / (right - left) + ortho_matrix[1, 3] = -(top + bottom) / (top - bottom) + ortho_matrix[2, 3] = -(far + near) / (far - near) + # ortho_matrix = np.eye(4, dtype=np.float32) + # ortho_matrix[0, 0] = 2 / (right - left) + # ortho_matrix[1, 1] = -2 / (top - bottom) + # ortho_matrix[2, 2] = -2 / (far - near) + # ortho_matrix[0, 3] = -(right + left) / (right - left) + # ortho_matrix[1, 3] = -(top + bottom) / (top - bottom) + # ortho_matrix[2, 3] = -(far + near) / (far - near) + return ortho_matrix + + +def get_perspective_projection_matrix(fovy, aspect_wh, near, far): + fovy_rad = math.radians(fovy) + return np.array( + [ + [1.0 / (math.tan(fovy_rad / 2.0) * aspect_wh), 0, 0, 0], + [0, 1.0 / math.tan(fovy_rad / 2.0), 0, 0], + [0, 0, -(far + near) / (far - near), -2.0 * far * near / (far - near)], + [0, 0, -1, 0], + ] + ).astype(np.float32) diff --git a/step1x3d_texture/differentiable_renderer/compile_mesh_painter.bat b/step1x3d_texture/differentiable_renderer/compile_mesh_painter.bat new file mode 100644 index 0000000000000000000000000000000000000000..3947b0f03f9f6245dac95db7460703076444a304 --- /dev/null +++ b/step1x3d_texture/differentiable_renderer/compile_mesh_painter.bat @@ -0,0 +1,3 @@ +FOR /F "tokens=*" %%i IN ('python -m pybind11 --includes') DO SET PYINCLUDES=%%i +echo %PYINCLUDES% +g++ -O3 -Wall -shared -std=c++11 -fPIC %PYINCLUDES% mesh_processor.cpp -o mesh_processor.pyd -lpython3.12 \ No newline at end of file diff --git a/step1x3d_texture/differentiable_renderer/mesh_processor.cpp b/step1x3d_texture/differentiable_renderer/mesh_processor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ca8650fada02099d3fce0f551fa4f953f278cf34 --- /dev/null +++ b/step1x3d_texture/differentiable_renderer/mesh_processor.cpp @@ -0,0 +1,161 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +std::pair, + py::array_t> meshVerticeInpaint_smooth(py::array_t texture, +py::array_t mask, + py::array_t vtx_pos, py::array_t vtx_uv, + py::array_t pos_idx, py::array_t uv_idx) { + auto texture_buf = texture.request(); + auto mask_buf = mask.request(); + auto vtx_pos_buf = vtx_pos.request(); + auto vtx_uv_buf = vtx_uv.request(); + auto pos_idx_buf = pos_idx.request(); + auto uv_idx_buf = uv_idx.request(); + + int texture_height = texture_buf.shape[0]; + int texture_width = texture_buf.shape[1]; + int texture_channel = texture_buf.shape[2]; + float* texture_ptr = static_cast(texture_buf.ptr); + uint8_t* mask_ptr = static_cast(mask_buf.ptr); + + int vtx_num = vtx_pos_buf.shape[0]; + float* vtx_pos_ptr = static_cast(vtx_pos_buf.ptr); + float* vtx_uv_ptr = static_cast(vtx_uv_buf.ptr); + int* pos_idx_ptr = static_cast(pos_idx_buf.ptr); + int* uv_idx_ptr = static_cast(uv_idx_buf.ptr); + + vector vtx_mask(vtx_num, 0.0f); + vector> vtx_color(vtx_num, vector(texture_channel, 0.0f)); + vector uncolored_vtxs; + + vector> G(vtx_num); + + for (int i = 0; i < uv_idx_buf.shape[0]; ++i) { + for (int k = 0; k < 3; ++k) { + int vtx_uv_idx = uv_idx_ptr[i * 3 + k]; + int vtx_idx = pos_idx_ptr[i * 3 + k]; + int uv_v = round(vtx_uv_ptr[vtx_uv_idx * 2] * (texture_width - 1)); + int uv_u = round((1.0 - vtx_uv_ptr[vtx_uv_idx * 2 + 1]) * (texture_height - 1)); + + if (mask_ptr[uv_u * texture_width + uv_v] > 0) { + vtx_mask[vtx_idx] = 1.0f; + for (int c = 0; c < texture_channel; ++c) { + vtx_color[vtx_idx][c] = texture_ptr[(uv_u * texture_width + uv_v) * texture_channel + c]; + } + }else{ + uncolored_vtxs.push_back(vtx_idx); + } + + G[pos_idx_ptr[i * 3 + k]].push_back(pos_idx_ptr[i * 3 + (k + 1) % 3]); + } + } + + int smooth_count = 2; + int last_uncolored_vtx_count = 0; + while (smooth_count>0) { + int uncolored_vtx_count = 0; + + for (int vtx_idx : uncolored_vtxs) { + + vector sum_color(texture_channel, 0.0f); + float total_weight = 0.0f; + + array vtx_0 = {vtx_pos_ptr[vtx_idx * 3], +vtx_pos_ptr[vtx_idx * 3 + 1], vtx_pos_ptr[vtx_idx * 3 + 2]}; + for (int connected_idx : G[vtx_idx]) { + if (vtx_mask[connected_idx] > 0) { + array vtx1 = {vtx_pos_ptr[connected_idx * 3], + vtx_pos_ptr[connected_idx * 3 + 1], vtx_pos_ptr[connected_idx * 3 + 2]}; + float dist_weight = 1.0f / max(sqrt(pow(vtx_0[0] - vtx1[0], 2) + pow(vtx_0[1] - vtx1[1], 2) + \ + pow(vtx_0[2] - vtx1[2], 2)), 1E-4); + dist_weight = dist_weight * dist_weight; + for (int c = 0; c < texture_channel; ++c) { + sum_color[c] += vtx_color[connected_idx][c] * dist_weight; + } + total_weight += dist_weight; + } + } + + if (total_weight > 0.0f) { + for (int c = 0; c < texture_channel; ++c) { + vtx_color[vtx_idx][c] = sum_color[c] / total_weight; + } + vtx_mask[vtx_idx] = 1.0f; + } else { + uncolored_vtx_count++; + } + + } + + if(last_uncolored_vtx_count==uncolored_vtx_count){ + smooth_count--; + }else{ + smooth_count++; + } + last_uncolored_vtx_count = uncolored_vtx_count; + } + + // Create new arrays for the output + py::array_t new_texture(texture_buf.size); + py::array_t new_mask(mask_buf.size); + + auto new_texture_buf = new_texture.request(); + auto new_mask_buf = new_mask.request(); + + float* new_texture_ptr = static_cast(new_texture_buf.ptr); + uint8_t* new_mask_ptr = static_cast(new_mask_buf.ptr); + // Copy original texture and mask to new arrays + std::copy(texture_ptr, texture_ptr + texture_buf.size, new_texture_ptr); + std::copy(mask_ptr, mask_ptr + mask_buf.size, new_mask_ptr); + + for (int face_idx = 0; face_idx < uv_idx_buf.shape[0]; ++face_idx) { + for (int k = 0; k < 3; ++k) { + int vtx_uv_idx = uv_idx_ptr[face_idx * 3 + k]; + int vtx_idx = pos_idx_ptr[face_idx * 3 + k]; + + if (vtx_mask[vtx_idx] == 1.0f) { + int uv_v = round(vtx_uv_ptr[vtx_uv_idx * 2] * (texture_width - 1)); + int uv_u = round((1.0 - vtx_uv_ptr[vtx_uv_idx * 2 + 1]) * (texture_height - 1)); + + for (int c = 0; c < texture_channel; ++c) { + new_texture_ptr[(uv_u * texture_width + uv_v) * texture_channel + c] = vtx_color[vtx_idx][c]; + } + new_mask_ptr[uv_u * texture_width + uv_v] = 255; + } + } + } + + // Reshape the new arrays to match the original texture and mask shapes + new_texture.resize({texture_height, texture_width, 3}); + new_mask.resize({texture_height, texture_width}); + return std::make_pair(new_texture, new_mask); +} + + +std::pair, py::array_t> meshVerticeInpaint(py::array_t texture, + py::array_t mask, + py::array_t vtx_pos, py::array_t vtx_uv, + py::array_t pos_idx, py::array_t uv_idx, const std::string& method = "smooth") { + if (method == "smooth") { + return meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx); + } else { + throw std::invalid_argument("Invalid method. Use 'smooth' or 'forward'."); + } +} + +PYBIND11_MODULE(mesh_processor, m) { + m.def("meshVerticeInpaint", &meshVerticeInpaint, "A function to process mesh", + py::arg("texture"), py::arg("mask"), + py::arg("vtx_pos"), py::arg("vtx_uv"), + py::arg("pos_idx"), py::arg("uv_idx"), + py::arg("method") = "smooth"); +} \ No newline at end of file diff --git a/step1x3d_texture/differentiable_renderer/mesh_processor.py b/step1x3d_texture/differentiable_renderer/mesh_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..76e9a081615606b54ba5a7666f289aa38c876c1f --- /dev/null +++ b/step1x3d_texture/differentiable_renderer/mesh_processor.py @@ -0,0 +1,90 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import numpy as np + + +def meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx): + texture_height, texture_width, texture_channel = texture.shape + vtx_num = vtx_pos.shape[0] + + vtx_mask = np.zeros(vtx_num, dtype=np.float32) + vtx_color = [np.zeros(texture_channel, dtype=np.float32) for _ in range(vtx_num)] + uncolored_vtxs = [] + G = [[] for _ in range(vtx_num)] + + for i in range(uv_idx.shape[0]): + for k in range(3): + vtx_uv_idx = uv_idx[i, k] + vtx_idx = pos_idx[i, k] + uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1))) + uv_u = int(round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))) + if mask[uv_u, uv_v] > 0: + vtx_mask[vtx_idx] = 1.0 + vtx_color[vtx_idx] = texture[uv_u, uv_v] + else: + uncolored_vtxs.append(vtx_idx) + G[pos_idx[i, k]].append(pos_idx[i, (k + 1) % 3]) + + smooth_count = 2 + last_uncolored_vtx_count = 0 + while smooth_count > 0: + uncolored_vtx_count = 0 + for vtx_idx in uncolored_vtxs: + sum_color = np.zeros(texture_channel, dtype=np.float32) + total_weight = 0.0 + vtx_0 = vtx_pos[vtx_idx] + for connected_idx in G[vtx_idx]: + if vtx_mask[connected_idx] > 0: + vtx1 = vtx_pos[connected_idx] + dist = np.sqrt(np.sum((vtx_0 - vtx1) ** 2)) + dist_weight = 1.0 / max(dist, 1e-4) + dist_weight *= dist_weight + sum_color += vtx_color[connected_idx] * dist_weight + total_weight += dist_weight + if total_weight > 0: + vtx_color[vtx_idx] = sum_color / total_weight + vtx_mask[vtx_idx] = 1.0 + else: + uncolored_vtx_count += 1 + + if last_uncolored_vtx_count == uncolored_vtx_count: + smooth_count -= 1 + else: + smooth_count += 1 + last_uncolored_vtx_count = uncolored_vtx_count + + new_texture = texture.copy() + new_mask = mask.copy() + for face_idx in range(uv_idx.shape[0]): + for k in range(3): + vtx_uv_idx = uv_idx[face_idx, k] + vtx_idx = pos_idx[face_idx, k] + if vtx_mask[vtx_idx] == 1.0: + uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1))) + uv_u = int(round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))) + new_texture[uv_u, uv_v] = vtx_color[vtx_idx] + new_mask[uv_u, uv_v] = 255 + return new_texture, new_mask + + +def meshVerticeInpaint( + texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx, method="smooth" +): + if method == "smooth": + return meshVerticeInpaint_smooth( + texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx + ) + else: + raise ValueError("Invalid method. Use 'smooth' or 'forward'.") diff --git a/step1x3d_texture/differentiable_renderer/mesh_render.py b/step1x3d_texture/differentiable_renderer/mesh_render.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc713b1febf7108506082582938409a8e5f3018 --- /dev/null +++ b/step1x3d_texture/differentiable_renderer/mesh_render.py @@ -0,0 +1,921 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import trimesh +from PIL import Image + +from .camera_utils import ( + transform_pos, + get_mv_matrix, + get_orthographic_projection_matrix, + get_perspective_projection_matrix, +) +from .mesh_processor import meshVerticeInpaint +from .mesh_utils import load_mesh, save_mesh + + +def stride_from_shape(shape): + stride = [1] + for x in reversed(shape[1:]): + stride.append(stride[-1] * x) + return list(reversed(stride)) + + +def scatter_add_nd_with_count(input, count, indices, values, weights=None): + # input: [..., C], D dimension + C channel + # count: [..., 1], D dimension + # indices: [N, D], long + # values: [N, C] + + D = indices.shape[-1] + C = input.shape[-1] + size = input.shape[:-1] + stride = stride_from_shape(size) + + assert len(size) == D + + input = input.view(-1, C) # [HW, C] + count = count.view(-1, 1) + + flatten_indices = ( + indices * torch.tensor(stride, dtype=torch.long, device=indices.device) + ).sum( + -1 + ) # [N] + + if weights is None: + weights = torch.ones_like(values[..., :1]) + + input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values) + count.scatter_add_(0, flatten_indices.unsqueeze(1), weights) + + return input.view(*size, C), count.view(*size, 1) + + +def linear_grid_put_2d(H, W, coords, values, return_count=False): + # coords: [N, 2], float in [0, 1] + # values: [N, C] + + C = values.shape[-1] + + indices = coords * torch.tensor( + [H - 1, W - 1], dtype=torch.float32, device=coords.device + ) + indices_00 = indices.floor().long() # [N, 2] + indices_00[:, 0].clamp_(0, H - 2) + indices_00[:, 1].clamp_(0, W - 2) + indices_01 = indices_00 + torch.tensor( + [0, 1], dtype=torch.long, device=indices.device + ) + indices_10 = indices_00 + torch.tensor( + [1, 0], dtype=torch.long, device=indices.device + ) + indices_11 = indices_00 + torch.tensor( + [1, 1], dtype=torch.long, device=indices.device + ) + + h = indices[..., 0] - indices_00[..., 0].float() + w = indices[..., 1] - indices_00[..., 1].float() + w_00 = (1 - h) * (1 - w) + w_01 = (1 - h) * w + w_10 = h * (1 - w) + w_11 = h * w + + result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C] + count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1] + weights = torch.ones_like(values[..., :1]) # [N, 1] + + result, count = scatter_add_nd_with_count( + result, + count, + indices_00, + values * w_00.unsqueeze(1), + weights * w_00.unsqueeze(1), + ) + result, count = scatter_add_nd_with_count( + result, + count, + indices_01, + values * w_01.unsqueeze(1), + weights * w_01.unsqueeze(1), + ) + result, count = scatter_add_nd_with_count( + result, + count, + indices_10, + values * w_10.unsqueeze(1), + weights * w_10.unsqueeze(1), + ) + result, count = scatter_add_nd_with_count( + result, + count, + indices_11, + values * w_11.unsqueeze(1), + weights * w_11.unsqueeze(1), + ) + + if return_count: + return result, count + + mask = count.squeeze(-1) > 0 + result[mask] = result[mask] / count[mask].repeat(1, C) + + return result + + +class MeshRender: + def __init__( + self, + camera_distance=1.8, + camera_type="orth", + default_resolution=1024, + texture_size=1024, + use_antialias=True, + max_mip_level=None, + filter_mode="linear", + bake_mode="linear", + raster_mode="cr", + device="cuda", + ): + + self.device = device + + self.set_default_render_resolution(default_resolution) + self.set_default_texture_resolution(texture_size) + + self.camera_distance = camera_distance + self.use_antialias = use_antialias + self.max_mip_level = max_mip_level + self.filter_mode = filter_mode + + self.bake_angle_thres = 75 + self.bake_unreliable_kernel_size = int( + (2 / 512) * max(self.default_resolution[0], self.default_resolution[1]) + ) + self.bake_mode = bake_mode + + self.raster_mode = raster_mode + if self.raster_mode == "cr": + import custom_rasterizer as cr + + self.raster = cr + else: + raise f"No raster named {self.raster_mode}" + + if camera_type == "orth": + self.ortho_scale = 1.1 + self.camera_proj_mat = get_orthographic_projection_matrix( + left=-self.ortho_scale * 0.5, + right=self.ortho_scale * 0.5, + bottom=-self.ortho_scale * 0.5, + top=self.ortho_scale * 0.5, + near=0.1, + far=100, + ) + elif camera_type == "perspective": + self.camera_proj_mat = get_perspective_projection_matrix( + 49.13, + self.default_resolution[1] / self.default_resolution[0], + 0.01, + 100.0, + ) + else: + raise f"No camera type {camera_type}" + + def raster_rasterize(self, pos, tri, resolution, ranges=None, grad_db=True): + + if self.raster_mode == "cr": + rast_out_db = None + if pos.dim() == 2: + pos = pos.unsqueeze(0) + findices, barycentric = self.raster.rasterize(pos, tri, resolution) + rast_out = torch.cat((barycentric, findices.unsqueeze(-1)), dim=-1) + rast_out = rast_out.unsqueeze(0) + else: + raise f"No raster named {self.raster_mode}" + + return rast_out, rast_out_db + + def raster_interpolate(self, uv, rast_out, uv_idx, rast_db=None, diff_attrs=None): + + if self.raster_mode == "cr": + textd = None + barycentric = rast_out[0, ..., :-1] + findices = rast_out[0, ..., -1] + if uv.dim() == 2: + uv = uv.unsqueeze(0) + textc = self.raster.interpolate(uv, findices, barycentric, uv_idx) + else: + raise f"No raster named {self.raster_mode}" + + return textc, textd + + def raster_texture( + self, + tex, + uv, + uv_da=None, + mip_level_bias=None, + mip=None, + filter_mode="auto", + boundary_mode="wrap", + max_mip_level=None, + ): + + if self.raster_mode == "cr": + raise f"Texture is not implemented in cr" + else: + raise f"No raster named {self.raster_mode}" + + return color + + def raster_antialias( + self, color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0 + ): + + if self.raster_mode == "cr": + # Antialias has not been supported yet + color = color + else: + raise f"No raster named {self.raster_mode}" + + return color + + def load_mesh( + self, + mesh, + scale_factor=1.15, + auto_center=True, + ): + vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data = load_mesh(mesh) + self.mesh_copy = mesh + self.set_mesh( + vtx_pos, + pos_idx, + vtx_uv=vtx_uv, + uv_idx=uv_idx, + scale_factor=scale_factor, + auto_center=auto_center, + ) + if texture_data is not None: + self.set_texture(texture_data) + + def save_mesh(self): + texture_data = self.get_texture() + texture_data = Image.fromarray((texture_data * 255).astype(np.uint8)) + return save_mesh(self.mesh_copy, texture_data) + + def set_mesh( + self, + vtx_pos, + pos_idx, + vtx_uv=None, + uv_idx=None, + scale_factor=1.15, + auto_center=True, + ): + + self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float() + self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int) + if (vtx_uv is not None) and (uv_idx is not None): + self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float() + self.uv_idx = torch.from_numpy(uv_idx).to(self.device).to(torch.int) + else: + self.vtx_uv = None + self.uv_idx = None + + self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]] + self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]] + if (vtx_uv is not None) and (uv_idx is not None): + self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1] + + if auto_center: + max_bb = (self.vtx_pos - 0).max(0)[0] + min_bb = (self.vtx_pos - 0).min(0)[0] + center = (max_bb + min_bb) / 2 + scale = torch.norm(self.vtx_pos - center, dim=1).max() * 2.0 + self.vtx_pos = (self.vtx_pos - center) * (scale_factor / float(scale)) + self.scale_factor = scale_factor + + def set_texture(self, tex): + if isinstance(tex, np.ndarray): + tex = Image.fromarray((tex * 255).astype(np.uint8)) + elif isinstance(tex, torch.Tensor): + tex = tex.cpu().numpy() + tex = Image.fromarray((tex * 255).astype(np.uint8)) + + tex = tex.resize(self.texture_size).convert("RGB") + tex = np.array(tex) / 255.0 + self.tex = torch.from_numpy(tex).to(self.device) + self.tex = self.tex.float() + + def set_default_render_resolution(self, default_resolution): + if isinstance(default_resolution, int): + default_resolution = (default_resolution, default_resolution) + self.default_resolution = default_resolution + + def set_default_texture_resolution(self, texture_size): + if isinstance(texture_size, int): + texture_size = (texture_size, texture_size) + self.texture_size = texture_size + + def get_mesh(self): + vtx_pos = self.vtx_pos.cpu().numpy() + pos_idx = self.pos_idx.cpu().numpy() + vtx_uv = self.vtx_uv.cpu().numpy() + uv_idx = self.uv_idx.cpu().numpy() + + # 坐标变换的逆变换 + vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]] + vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]] + + vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1] + return vtx_pos, pos_idx, vtx_uv, uv_idx + + def get_texture(self): + return self.tex.cpu().numpy() + + def to(self, device): + self.device = device + + for attr_name in dir(self): + attr_value = getattr(self, attr_name) + if isinstance(attr_value, torch.Tensor): + setattr(self, attr_name, attr_value.to(self.device)) + + def color_rgb_to_srgb(self, image): + if isinstance(image, Image.Image): + image_rgb = torch.tesnor(np.array(image) / 255.0).float().to(self.device) + elif isinstance(image, np.ndarray): + image_rgb = torch.tensor(image).float() + else: + image_rgb = image.to(self.device) + + image_srgb = torch.where( + image_rgb <= 0.0031308, + 12.92 * image_rgb, + 1.055 * torch.pow(image_rgb, 1 / 2.4) - 0.055, + ) + + if isinstance(image, Image.Image): + image_srgb = Image.fromarray( + (image_srgb.cpu().numpy() * 255).astype(np.uint8) + ) + elif isinstance(image, np.ndarray): + image_srgb = image_srgb.cpu().numpy() + else: + image_srgb = image_srgb.to(image.device) + + return image_srgb + + def _render( + self, + mvp, + pos, + pos_idx, + uv, + uv_idx, + tex, + resolution, + max_mip_level, + keep_alpha, + filter_mode, + ): + pos_clip = transform_pos(mvp, pos) + if isinstance(resolution, (int, float)): + resolution = [resolution, resolution] + rast_out, rast_out_db = self.raster_rasterize( + pos_clip, pos_idx, resolution=resolution + ) + + tex = tex.contiguous() + if filter_mode == "linear-mipmap-linear": + texc, texd = self.raster_interpolate( + uv[None, ...], rast_out, uv_idx, rast_db=rast_out_db, diff_attrs="all" + ) + color = self.raster_texture( + tex[None, ...], + texc, + texd, + filter_mode="linear-mipmap-linear", + max_mip_level=max_mip_level, + ) + else: + texc, _ = self.raster_interpolate(uv[None, ...], rast_out, uv_idx) + color = self.raster_texture(tex[None, ...], texc, filter_mode=filter_mode) + + visible_mask = torch.clamp(rast_out[..., -1:], 0, 1) + color = color * visible_mask # Mask out background. + if self.use_antialias: + color = self.raster_antialias(color, rast_out, pos_clip, pos_idx) + + if keep_alpha: + color = torch.cat([color, visible_mask], dim=-1) + return color[0, ...] + + def render( + self, + elev, + azim, + camera_distance=None, + center=None, + resolution=None, + tex=None, + keep_alpha=True, + bgcolor=None, + filter_mode=None, + return_type="th", + ): + + proj = self.camera_proj_mat + r_mv = get_mv_matrix( + elev=elev, + azim=azim, + camera_distance=( + self.camera_distance if camera_distance is None else camera_distance + ), + center=center, + ) + r_mvp = np.matmul(proj, r_mv).astype(np.float32) + if tex is not None: + if isinstance(tex, Image.Image): + tex = torch.tensor(np.array(tex) / 255.0) + elif isinstance(tex, np.ndarray): + tex = torch.tensor(tex) + if tex.dim() == 2: + tex = tex.unsqueeze(-1) + tex = tex.float().to(self.device) + image = self._render( + r_mvp, + self.vtx_pos, + self.pos_idx, + self.vtx_uv, + self.uv_idx, + self.tex if tex is None else tex, + self.default_resolution if resolution is None else resolution, + self.max_mip_level, + True, + filter_mode if filter_mode else self.filter_mode, + ) + mask = (image[..., [-1]] == 1).float() + if bgcolor is None: + bgcolor = [0 for _ in range(image.shape[-1] - 1)] + image = image * mask + (1 - mask) * torch.tensor(bgcolor + [0]).to(self.device) + if keep_alpha == False: + image = image[..., :-1] + if return_type == "np": + image = image.cpu().numpy() + elif return_type == "pl": + image = image.squeeze(-1).cpu().numpy() * 255 + image = Image.fromarray(image.astype(np.uint8)) + return image + + def render_normal( + self, + elev, + azim, + camera_distance=None, + center=None, + resolution=None, + bg_color=[1, 1, 1], + use_abs_coor=False, + normalize_rgb=True, + return_type="th", + ): + + pos_camera, pos_clip = self.get_pos_from_mvp( + elev, azim, camera_distance, center + ) + if resolution is None: + resolution = self.default_resolution + if isinstance(resolution, (int, float)): + resolution = [resolution, resolution] + rast_out, rast_out_db = self.raster_rasterize( + pos_clip, self.pos_idx, resolution=resolution + ) + + if use_abs_coor: + mesh_triangles = self.vtx_pos[self.pos_idx[:, :3], :] + else: + pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4] + mesh_triangles = pos_camera[self.pos_idx[:, :3], :] + face_normals = F.normalize( + torch.cross( + mesh_triangles[:, 1, :] - mesh_triangles[:, 0, :], + mesh_triangles[:, 2, :] - mesh_triangles[:, 0, :], + dim=-1, + ), + dim=-1, + ) + + vertex_normals = trimesh.geometry.mean_vertex_normals( + vertex_count=self.vtx_pos.shape[0], + faces=self.pos_idx.cpu(), + face_normals=face_normals.cpu(), + ) + vertex_normals = ( + torch.from_numpy(vertex_normals).float().to(self.device).contiguous() + ) + + # Interpolate normal values across the rasterized pixels + normal, _ = self.raster_interpolate( + vertex_normals[None, ...], rast_out, self.pos_idx + ) + + visible_mask = torch.clamp(rast_out[..., -1:], 0, 1) + normal = normal * visible_mask + torch.tensor( + bg_color, dtype=torch.float32, device=self.device + ) * (1 - visible_mask) + + if normalize_rgb: + normal = (normal + 1) * 0.5 + if self.use_antialias: + normal = self.raster_antialias(normal, rast_out, pos_clip, self.pos_idx) + + image = normal[0, ...] + if return_type == "np": + image = image.cpu().numpy() + elif return_type == "pl": + image = image.cpu().numpy() * 255 + image = Image.fromarray(image.astype(np.uint8)) + + return image + + def convert_normal_map(self, image): + # blue is front, red is left, green is top + if isinstance(image, Image.Image): + image = np.array(image) + mask = (image == [255, 255, 255]).all(axis=-1) + + image = (image / 255.0) * 2.0 - 1.0 + + image[..., [1]] = -image[..., [1]] + image[..., [1, 2]] = image[..., [2, 1]] + image[..., [0]] = -image[..., [0]] + + image = (image + 1.0) * 0.5 + + image = (image * 255).astype(np.uint8) + image[mask] = [127, 127, 255] + + return Image.fromarray(image) + + def get_pos_from_mvp(self, elev, azim, camera_distance, center): + proj = self.camera_proj_mat + r_mv = get_mv_matrix( + elev=elev, + azim=azim, + camera_distance=( + self.camera_distance if camera_distance is None else camera_distance + ), + center=center, + ) + + pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True) + pos_clip = transform_pos(proj, pos_camera) + + return pos_camera, pos_clip + + def render_depth( + self, + elev, + azim, + camera_distance=None, + center=None, + resolution=None, + return_type="th", + ): + pos_camera, pos_clip = self.get_pos_from_mvp( + elev, azim, camera_distance, center + ) + + if resolution is None: + resolution = self.default_resolution + if isinstance(resolution, (int, float)): + resolution = [resolution, resolution] + rast_out, rast_out_db = self.raster_rasterize( + pos_clip, self.pos_idx, resolution=resolution + ) + + pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4] + tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous() + + # Interpolate depth values across the rasterized pixels + depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx) + + visible_mask = torch.clamp(rast_out[..., -1:], 0, 1) + depth_max, depth_min = ( + depth[visible_mask > 0].max(), + depth[visible_mask > 0].min(), + ) + depth = (depth - depth_min) / (depth_max - depth_min) + + depth = depth * visible_mask # Mask out background. + if self.use_antialias: + depth = self.raster_antialias(depth, rast_out, pos_clip, self.pos_idx) + + image = depth[0, ...] + if return_type == "np": + image = image.cpu().numpy() + elif return_type == "pl": + image = image.squeeze(-1).cpu().numpy() * 255 + image = Image.fromarray(image.astype(np.uint8)) + return image + + def render_position( + self, + elev, + azim, + camera_distance=None, + center=None, + resolution=None, + bg_color=[1, 1, 1], + return_type="th", + ): + pos_camera, pos_clip = self.get_pos_from_mvp( + elev, azim, camera_distance, center + ) + if resolution is None: + resolution = self.default_resolution + if isinstance(resolution, (int, float)): + resolution = [resolution, resolution] + rast_out, rast_out_db = self.raster_rasterize( + pos_clip, self.pos_idx, resolution=resolution + ) + + tex_position = 0.5 - self.vtx_pos[:, :3] / self.scale_factor + tex_position = tex_position.contiguous() + + # Interpolate depth values across the rasterized pixels + position, _ = self.raster_interpolate( + tex_position[None, ...], rast_out, self.pos_idx + ) + + visible_mask = torch.clamp(rast_out[..., -1:], 0, 1) + + position = position * visible_mask + torch.tensor( + bg_color, dtype=torch.float32, device=self.device + ) * (1 - visible_mask) + if self.use_antialias: + position = self.raster_antialias(position, rast_out, pos_clip, self.pos_idx) + + image = position[0, ...] + + if return_type == "np": + image = image.cpu().numpy() + elif return_type == "pl": + image = image.squeeze(-1).cpu().numpy() * 255 + image = Image.fromarray(image.astype(np.uint8)) + return image + + def render_uvpos(self, return_type="th"): + image = self.uv_feature_map(self.vtx_pos * 0.5 + 0.5) + if return_type == "np": + image = image.cpu().numpy() + elif return_type == "pl": + image = image.cpu().numpy() * 255 + image = Image.fromarray(image.astype(np.uint8)) + return image + + def uv_feature_map(self, vert_feat, bg=None): + vtx_uv = self.vtx_uv * 2 - 1.0 + vtx_uv = torch.cat([vtx_uv, torch.zeros_like(self.vtx_uv)], dim=1).unsqueeze(0) + vtx_uv[..., -1] = 1 + uv_idx = self.uv_idx + rast_out, rast_out_db = self.raster_rasterize( + vtx_uv, uv_idx, resolution=self.texture_size + ) + feat_map, _ = self.raster_interpolate(vert_feat[None, ...], rast_out, uv_idx) + feat_map = feat_map[0, ...] + if bg is not None: + visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...] + feat_map[visible_mask == 0] = bg + return feat_map + + def render_sketch_from_geometry(self, normal_image, depth_image): + normal_image_np = normal_image.cpu().numpy() + depth_image_np = depth_image.cpu().numpy() + + normal_image_np = (normal_image_np * 255).astype(np.uint8) + depth_image_np = (depth_image_np * 255).astype(np.uint8) + normal_image_np = cv2.cvtColor(normal_image_np, cv2.COLOR_RGB2GRAY) + + normal_edges = cv2.Canny(normal_image_np, 80, 150) + depth_edges = cv2.Canny(depth_image_np, 30, 80) + + combined_edges = np.maximum(normal_edges, depth_edges) + + sketch_image = ( + torch.from_numpy(combined_edges).to(normal_image.device).float() / 255.0 + ) + sketch_image = sketch_image.unsqueeze(-1) + + return sketch_image + + def render_sketch_from_depth(self, depth_image): + depth_image_np = depth_image.cpu().numpy() + depth_image_np = (depth_image_np * 255).astype(np.uint8) + depth_edges = cv2.Canny(depth_image_np, 30, 80) + combined_edges = depth_edges + sketch_image = ( + torch.from_numpy(combined_edges).to(depth_image.device).float() / 255.0 + ) + sketch_image = sketch_image.unsqueeze(-1) + return sketch_image + + def back_project( + self, image, elev, azim, camera_distance=None, center=None, method=None + ): + if isinstance(image, Image.Image): + image = torch.tensor(np.array(image) / 255.0) + elif isinstance(image, np.ndarray): + image = torch.tensor(image) + if image.dim() == 2: + image = image.unsqueeze(-1) + image = image.float().to(self.device) + resolution = image.shape[:2] + channel = image.shape[-1] + texture = torch.zeros(self.texture_size + (channel,)).to(self.device) + cos_map = torch.zeros(self.texture_size + (1,)).to(self.device) + + proj = self.camera_proj_mat + r_mv = get_mv_matrix( + elev=elev, + azim=azim, + camera_distance=( + self.camera_distance if camera_distance is None else camera_distance + ), + center=center, + ) + pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True) + pos_clip = transform_pos(proj, pos_camera) + pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4] + v0 = pos_camera[self.pos_idx[:, 0], :] + v1 = pos_camera[self.pos_idx[:, 1], :] + v2 = pos_camera[self.pos_idx[:, 2], :] + face_normals = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1) + vertex_normals = trimesh.geometry.mean_vertex_normals( + vertex_count=self.vtx_pos.shape[0], + faces=self.pos_idx.cpu(), + face_normals=face_normals.cpu(), + ) + vertex_normals = ( + torch.from_numpy(vertex_normals).float().to(self.device).contiguous() + ) + tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous() + rast_out, rast_out_db = self.raster_rasterize( + pos_clip, self.pos_idx, resolution=resolution + ) + visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...] + + normal, _ = self.raster_interpolate( + vertex_normals[None, ...], rast_out, self.pos_idx + ) + normal = normal[0, ...] + uv, _ = self.raster_interpolate(self.vtx_uv[None, ...], rast_out, self.uv_idx) + depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx) + depth = depth[0, ...] + + depth_max, depth_min = ( + depth[visible_mask > 0].max(), + depth[visible_mask > 0].min(), + ) + depth_normalized = (depth - depth_min) / (depth_max - depth_min) + depth_image = depth_normalized * visible_mask # Mask out background. + + sketch_image = self.render_sketch_from_depth(depth_image) + + lookat = torch.tensor([[0, 0, -1]], device=self.device) + cos_image = torch.nn.functional.cosine_similarity(lookat, normal.view(-1, 3)) + cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1) + + cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi) + cos_image[cos_image < cos_thres] = 0 + + # shrink + kernel_size = self.bake_unreliable_kernel_size * 2 + 1 + kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32).to( + sketch_image.device + ) + + visible_mask = visible_mask.permute(2, 0, 1).unsqueeze(0).float() + visible_mask = F.conv2d(1.0 - visible_mask, kernel, padding=kernel_size // 2) + visible_mask = 1.0 - (visible_mask > 0).float() # 二值化 + visible_mask = visible_mask.squeeze(0).permute(1, 2, 0) + + sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0) + sketch_image = F.conv2d(sketch_image, kernel, padding=kernel_size // 2) + sketch_image = (sketch_image > 0).float() # 二值化 + sketch_image = sketch_image.squeeze(0).permute(1, 2, 0) + visible_mask = visible_mask * (sketch_image < 0.5) + + cos_image[visible_mask == 0] = 0 + + method = self.bake_mode if method is None else method + + if method == "linear": + proj_mask = (visible_mask != 0).view(-1) + uv = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask] + image = image.squeeze(0).contiguous().view(-1, channel)[proj_mask] + cos_image = cos_image.contiguous().view(-1, 1)[proj_mask] + sketch_image = sketch_image.contiguous().view(-1, 1)[proj_mask] + + texture = linear_grid_put_2d( + self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], image + ) + cos_map = linear_grid_put_2d( + self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], cos_image + ) + boundary_map = linear_grid_put_2d( + self.texture_size[1], + self.texture_size[0], + uv[..., [1, 0]], + sketch_image, + ) + else: + raise f"No bake mode {method}" + + return texture, cos_map, boundary_map + + def bake_texture( + self, + colors, + elevs, + azims, + camera_distance=None, + center=None, + exp=6, + weights=None, + ): + for i in range(len(colors)): + if isinstance(colors[i], Image.Image): + colors[i] = torch.tensor( + np.array(colors[i]) / 255.0, device=self.device + ).float() + if weights is None: + weights = [1.0 for _ in range(colors)] + textures = [] + cos_maps = [] + for color, elev, azim, weight in zip(colors, elevs, azims, weights): + texture, cos_map, _ = self.back_project( + color, elev, azim, camera_distance, center + ) + cos_map = weight * (cos_map**exp) + textures.append(texture) + cos_maps.append(cos_map) + + texture_merge, trust_map_merge = self.fast_bake_texture(textures, cos_maps) + return texture_merge, trust_map_merge + + @torch.no_grad() + def fast_bake_texture(self, textures, cos_maps): + + channel = textures[0].shape[-1] + texture_merge = torch.zeros(self.texture_size + (channel,)).to(self.device) + trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device) + for texture, cos_map in zip(textures, cos_maps): + view_sum = (cos_map > 0).sum() + painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum() + if painted_sum / view_sum > 0.99: + continue + texture_merge += texture * cos_map + trust_map_merge += cos_map + texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8) + + return texture_merge, trust_map_merge > 1e-8 + + def uv_inpaint(self, texture, mask): + + if isinstance(texture, torch.Tensor): + texture_np = texture.cpu().numpy() + elif isinstance(texture, np.ndarray): + texture_np = texture + elif isinstance(texture, Image.Image): + texture_np = np.array(texture) / 255.0 + + vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh() + + texture_np, mask = meshVerticeInpaint( + texture_np, mask, vtx_pos, vtx_uv, pos_idx, uv_idx + ) + + texture_np = cv2.inpaint( + (texture_np * 255).astype(np.uint8), 255 - mask, 3, cv2.INPAINT_NS + ) + + return texture_np diff --git a/step1x3d_texture/differentiable_renderer/mesh_utils.py b/step1x3d_texture/differentiable_renderer/mesh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e27d67dd858e947b96d2e0a6d26f4789a6e33a97 --- /dev/null +++ b/step1x3d_texture/differentiable_renderer/mesh_utils.py @@ -0,0 +1,38 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import trimesh + + +def load_mesh(mesh): + vtx_pos = mesh.vertices if hasattr(mesh, "vertices") else None + pos_idx = mesh.faces if hasattr(mesh, "faces") else None + + vtx_uv = mesh.visual.uv if hasattr(mesh.visual, "uv") else None + uv_idx = mesh.faces if hasattr(mesh, "faces") else None + + texture_data = None + + return vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data + + +def save_mesh(mesh, texture_data): + material = trimesh.visual.texture.SimpleMaterial( + image=texture_data, diffuse=(255, 255, 255) + ) + texture_visuals = trimesh.visual.TextureVisuals( + uv=mesh.visual.uv, image=texture_data, material=material + ) + mesh.visual = texture_visuals + return mesh diff --git a/step1x3d_texture/differentiable_renderer/setup.py b/step1x3d_texture/differentiable_renderer/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..5975f810913dbf7376ab03bc2b2a37f0b200815f --- /dev/null +++ b/step1x3d_texture/differentiable_renderer/setup.py @@ -0,0 +1,90 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +from setuptools import setup, Extension +import pybind11 +import sys +import platform + + +def get_platform_specific_args(): + system = platform.system().lower() + cpp_std = "c++14" # Make configurable if needed + + if sys.platform == "win32": + compile_args = [ + "/O2", + f"/std:{cpp_std}", + "/EHsc", + "/MP", + "/DWIN32_LEAN_AND_MEAN", + "/bigobj", + ] + link_args = [] + extra_includes = [] + elif system == "linux": + compile_args = [ + "-O3", + f"-std={cpp_std}", + "-fPIC", + "-Wall", + "-Wextra", + "-pthread", + ] + link_args = ["-fPIC", "-pthread"] + extra_includes = [] + elif sys.platform == "darwin": + compile_args = [ + "-O3", + f"-std={cpp_std}", + "-fPIC", + "-Wall", + "-Wextra", + "-stdlib=libc++", + "-mmacosx-version-min=10.14", + ] + link_args = [ + "-fPIC", + "-stdlib=libc++", + "-mmacosx-version-min=10.14", + "-dynamiclib", + ] + extra_includes = [] + else: + raise RuntimeError(f"Unsupported platform: {system}") + + return compile_args, link_args, extra_includes + + +extra_compile_args, extra_link_args, platform_includes = get_platform_specific_args() +include_dirs = [pybind11.get_include(), pybind11.get_include(user=True)] +include_dirs.extend(platform_includes) + +ext_modules = [ + Extension( + "mesh_processor", + ["mesh_processor.cpp"], + include_dirs=include_dirs, + language="c++", + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ), +] + +setup( + name="mesh_processor", + ext_modules=ext_modules, + install_requires=["pybind11>=2.6.0"], + python_requires=">=3.6", +) diff --git a/step1x3d_texture/loaders/__init__.py b/step1x3d_texture/loaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f93ecab4e9cb397dc0050b42edce64daa84b74f0 --- /dev/null +++ b/step1x3d_texture/loaders/__init__.py @@ -0,0 +1 @@ +from .custom_adapter import CustomAdapterMixin diff --git a/step1x3d_texture/loaders/custom_adapter.py b/step1x3d_texture/loaders/custom_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ac0b54f07b4f1fd5e60e299ee9b672c1ea605d --- /dev/null +++ b/step1x3d_texture/loaders/custom_adapter.py @@ -0,0 +1,98 @@ +import os +from typing import Dict, Optional, Union + +import safetensors +import torch +from diffusers.utils import _get_model_file, logging +from safetensors import safe_open + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CustomAdapterMixin: + def init_custom_adapter(self, *args, **kwargs): + self._init_custom_adapter(*args, **kwargs) + + def _init_custom_adapter(self, *args, **kwargs): + raise NotImplementedError + + def load_custom_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + weight_name: str, + subfolder: Optional[str] = None, + **kwargs, + ): + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + subfolder=subfolder, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + self._load_custom_adapter(state_dict) + + def _load_custom_adapter(self, state_dict): + raise NotImplementedError + + def save_custom_adapter( + self, + save_directory: Union[str, os.PathLike], + weight_name: str, + safe_serialization: bool = False, + **kwargs, + ): + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file( + weights, filename, metadata={"format": "pt"} + ) + + else: + save_function = torch.save + + # Save the model + state_dict = self._save_custom_adapter(**kwargs) + save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info( + f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}" + ) + + def _save_custom_adapter(self): + raise NotImplementedError diff --git a/step1x3d_texture/models/__init__.py b/step1x3d_texture/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/step1x3d_texture/models/attention_processor.py b/step1x3d_texture/models/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..55eebb86b2205a8344ca0239c2877f0df7cb1dcf --- /dev/null +++ b/step1x3d_texture/models/attention_processor.py @@ -0,0 +1,743 @@ +import math +from typing import Callable, List, Optional, Union + +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention +from diffusers.models.unets import UNet2DConditionModel +from diffusers.utils import deprecate, logging +from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available +from einops import rearrange, repeat +from torch import nn + + +def default_set_attn_proc_func( + name: str, + hidden_size: int, + cross_attention_dim: Optional[int], + ori_attn_proc: object, +) -> object: + return ori_attn_proc + + +def set_unet_2d_condition_attn_processor( + unet: UNet2DConditionModel, + set_self_attn_proc_func: Callable = default_set_attn_proc_func, + set_cross_attn_proc_func: Callable = default_set_attn_proc_func, + set_custom_attn_proc_func: Callable = default_set_attn_proc_func, + set_self_attn_module_names: Optional[List[str]] = None, + set_cross_attn_module_names: Optional[List[str]] = None, + set_custom_attn_module_names: Optional[List[str]] = None, +) -> None: + do_set_processor = lambda name, module_names: ( + any([name.startswith(module_name) for module_name in module_names]) + if module_names is not None + else True + ) # prefix match + + attn_procs = {} + for name, attn_processor in unet.attn_processors.items(): + # set attn_processor by default, if module_names is None + set_self_attn_processor = do_set_processor(name, set_self_attn_module_names) + set_cross_attn_processor = do_set_processor(name, set_cross_attn_module_names) + set_custom_attn_processor = do_set_processor(name, set_custom_attn_module_names) + + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + is_custom = "attn_mid_blocks" in name or "attn_post_blocks" in name + if is_custom: + attn_procs[name] = ( + set_custom_attn_proc_func(name, hidden_size, None, attn_processor) + if set_custom_attn_processor + else attn_processor + ) + else: + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else unet.config.cross_attention_dim + ) + if cross_attention_dim is None or "motion_modules" in name: + # self attention + attn_procs[name] = ( + set_self_attn_proc_func( + name, hidden_size, cross_attention_dim, attn_processor + ) + if set_self_attn_processor + else attn_processor + ) + else: + # cross attention + attn_procs[name] = ( + set_cross_attn_proc_func( + name, hidden_size, cross_attention_dim, attn_processor + ) + if set_cross_attn_processor + else attn_processor + ) + + unet.set_attn_processor(attn_procs) + + +class DecoupledMVRowSelfAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for Decoupled Row-wise Self-Attention and Image Cross-Attention for PyTorch 2.0. + """ + + def __init__( + self, + query_dim: int, + inner_dim: int, + num_views: int = 1, + name: Optional[str] = None, + use_mv: bool = True, + use_ref: bool = False, + ): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + super().__init__() + + self.num_views = num_views + self.name = name # NOTE: need for image cross-attention + self.use_mv = use_mv + self.use_ref = use_ref + + if self.use_mv: + self.to_q_mv = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_k_mv = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_v_mv = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_out_mv = nn.ModuleList( + [ + nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True), + nn.Dropout(0.0), + ] + ) + + if self.use_ref: + self.to_q_ref = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_k_ref = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_v_ref = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_out_ref = nn.ModuleList( + [ + nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True), + nn.Dropout(0.0), + ] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + mv_scale: float = 1.0, + ref_hidden_states: Optional[torch.FloatTensor] = None, + ref_scale: float = 1.0, + cache_hidden_states: Optional[List[torch.FloatTensor]] = None, + use_mv: bool = True, + use_ref: bool = True, + num_views: Optional[int] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + """ + New args: + mv_scale (float): scale for multi-view self-attention. + ref_hidden_states (torch.FloatTensor): reference encoder hidden states for image cross-attention. + ref_scale (float): scale for image cross-attention. + cache_hidden_states (List[torch.FloatTensor]): cache hidden states from reference unet. + + """ + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + if num_views is not None: + self.num_views = num_views + + # NEW: cache hidden states for reference unet + if cache_hidden_states is not None: + cache_hidden_states[self.name] = hidden_states.clone() + + # NEW: whether to use multi-view attention and image cross-attention + use_mv = self.use_mv and use_mv + use_ref = self.use_ref and use_ref + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + # NEW: for decoupled multi-view attention + if use_mv: + query_mv = self.to_q_mv(hidden_states) + + # NEW: for decoupled reference cross attention + if use_ref: + query_ref = self.to_q_ref(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + ####### Decoupled multi-view self-attention ######## + if use_mv: + key_mv = self.to_k_mv(encoder_hidden_states) + value_mv = self.to_v_mv(encoder_hidden_states) + + query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim) + key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim) + value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim) + + height = width = math.isqrt(sequence_length) + + # row self-attention + query_mv = rearrange( + query_mv, + "(b nv) (ih iw) h c -> (b nv ih) iw h c", + nv=self.num_views, + ih=height, + iw=width, + ).transpose(1, 2) + key_mv = rearrange( + key_mv, + "(b nv) (ih iw) h c -> b ih (nv iw) h c", + nv=self.num_views, + ih=height, + iw=width, + ) + key_mv = ( + key_mv.repeat_interleave(self.num_views, dim=0) + .view(batch_size * height, -1, attn.heads, head_dim) + .transpose(1, 2) + ) + value_mv = rearrange( + value_mv, + "(b nv) (ih iw) h c -> b ih (nv iw) h c", + nv=self.num_views, + ih=height, + iw=width, + ) + value_mv = ( + value_mv.repeat_interleave(self.num_views, dim=0) + .view(batch_size * height, -1, attn.heads, head_dim) + .transpose(1, 2) + ) + + hidden_states_mv = F.scaled_dot_product_attention( + query_mv, + key_mv, + value_mv, + dropout_p=0.0, + is_causal=False, + ) + hidden_states_mv = rearrange( + hidden_states_mv, + "(b nv ih) h iw c -> (b nv) (ih iw) (h c)", + nv=self.num_views, + ih=height, + ) + hidden_states_mv = hidden_states_mv.to(query.dtype) + + # linear proj + hidden_states_mv = self.to_out_mv[0](hidden_states_mv) + # dropout + hidden_states_mv = self.to_out_mv[1](hidden_states_mv) + + if use_ref: + reference_hidden_states = ref_hidden_states[self.name] + + key_ref = self.to_k_ref(reference_hidden_states) + value_ref = self.to_v_ref(reference_hidden_states) + + query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + + hidden_states_ref = F.scaled_dot_product_attention( + query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False + ) + + hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states_ref = hidden_states_ref.to(query.dtype) + + # linear proj + hidden_states_ref = self.to_out_ref[0](hidden_states_ref) + # dropout + hidden_states_ref = self.to_out_ref[1](hidden_states_ref) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if use_mv: + hidden_states = hidden_states + hidden_states_mv * mv_scale + + if use_ref: + hidden_states = hidden_states + hidden_states_ref * ref_scale + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def set_num_views(self, num_views: int) -> None: + self.num_views = num_views + + +class DecoupledMVRowColSelfAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for Decoupled Row-wise Self-Attention and Image Cross-Attention for PyTorch 2.0. + """ + + def __init__( + self, + query_dim: int, + inner_dim: int, + num_views: int = 1, + name: Optional[str] = None, + use_mv: bool = True, + use_ref: bool = False, + ): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + super().__init__() + + self.num_views = num_views + self.name = name # NOTE: need for image cross-attention + self.use_mv = use_mv + self.use_ref = use_ref + + if self.use_mv: + self.to_q_mv = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_k_mv = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_v_mv = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_out_mv = nn.ModuleList( + [ + nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True), + nn.Dropout(0.0), + ] + ) + + if self.use_ref: + self.to_q_ref = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_k_ref = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_v_ref = nn.Linear( + in_features=query_dim, out_features=inner_dim, bias=False + ) + self.to_out_ref = nn.ModuleList( + [ + nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True), + nn.Dropout(0.0), + ] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + mv_scale: float = 1.0, + ref_hidden_states: Optional[torch.FloatTensor] = None, + ref_scale: float = 1.0, + cache_hidden_states: Optional[List[torch.FloatTensor]] = None, + use_mv: bool = True, + use_ref: bool = True, + num_views: Optional[int] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + """ + New args: + mv_scale (float): scale for multi-view self-attention. + ref_hidden_states (torch.FloatTensor): reference encoder hidden states for image cross-attention. + ref_scale (float): scale for image cross-attention. + cache_hidden_states (List[torch.FloatTensor]): cache hidden states from reference unet. + + """ + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + if num_views is not None: + self.num_views = num_views + + # NEW: cache hidden states for reference unet + if cache_hidden_states is not None: + cache_hidden_states[self.name] = hidden_states.clone() + + # NEW: whether to use multi-view attention and image cross-attention + use_mv = self.use_mv and use_mv + use_ref = self.use_ref and use_ref + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + # NEW: for decoupled multi-view attention + if use_mv: + query_mv = self.to_q_mv(hidden_states) + + # NEW: for decoupled reference cross attention + if use_ref: + query_ref = self.to_q_ref(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + ####### Decoupled multi-view self-attention ######## + if use_mv: + key_mv = self.to_k_mv(encoder_hidden_states) + value_mv = self.to_v_mv(encoder_hidden_states) + + query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim) + key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim) + value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim) + + height = width = math.isqrt(sequence_length) + + query_mv = rearrange( + query_mv, + "(b nv) (ih iw) h c -> b nv ih iw h c", + nv=self.num_views, + ih=height, + iw=width, + ) + key_mv = rearrange( + key_mv, + "(b nv) (ih iw) h c -> b nv ih iw h c", + nv=self.num_views, + ih=height, + iw=width, + ) + value_mv = rearrange( + value_mv, + "(b nv) (ih iw) h c -> b nv ih iw h c", + nv=self.num_views, + ih=height, + iw=width, + ) + + # row-wise attention for view 0123 (front, right, back, left) + query_mv_0123 = rearrange( + query_mv[:, 0:4], "b nv ih iw h c -> (b ih) h (nv iw) c" + ) + key_mv_0123 = rearrange( + key_mv[:, 0:4], "b nv ih iw h c -> (b ih) h (nv iw) c" + ) + value_mv_0123 = rearrange( + value_mv[:, 0:4], "b nv ih iw h c -> (b ih) h (nv iw) c" + ) + hidden_states_mv_0123 = F.scaled_dot_product_attention( + query_mv_0123, + key_mv_0123, + value_mv_0123, + dropout_p=0.0, + is_causal=False, + ) + hidden_states_mv_0123 = rearrange( + hidden_states_mv_0123, + "(b ih) h (nv iw) c -> b nv (ih iw) (h c)", + ih=height, + iw=height, + ) + + # col-wise attention for view 0245 (front, back, top, bottom) + # flip first + query_mv_0245 = torch.cat( + [ + torch.flip(query_mv[:, [0]], [3]), # horizontal flip + query_mv[:, [2, 4, 5]], + ], + dim=1, + ) + key_mv_0245 = torch.cat( + [ + torch.flip(key_mv[:, [0]], [3]), # horizontal flip + key_mv[:, [2, 4, 5]], + ], + dim=1, + ) + value_mv_0245 = torch.cat( + [ + torch.flip(value_mv[:, [0]], [3]), # horizontal flip + value_mv[:, [2, 4, 5]], + ], + dim=1, + ) + # attention + query_mv_0245 = rearrange( + query_mv_0245, "b nv ih iw h c -> (b iw) h (nv ih) c" + ) + key_mv_0245 = rearrange(key_mv_0245, "b nv ih iw h c -> (b iw) h (nv ih) c") + value_mv_0245 = rearrange( + value_mv_0245, "b nv ih iw h c -> (b iw) h (nv ih) c" + ) + hidden_states_mv_0245 = F.scaled_dot_product_attention( + query_mv_0245, + key_mv_0245, + value_mv_0245, + dropout_p=0.0, + is_causal=False, + ) + # flip back + hidden_states_mv_0245 = rearrange( + hidden_states_mv_0245, + "(b iw) h (nv ih) c -> b nv ih iw (h c)", + ih=height, + iw=height, + ) + hidden_states_mv_0245 = torch.cat( + [ + torch.flip(hidden_states_mv_0245[:, [0]], [3]), # horizontal flip + hidden_states_mv_0245[:, [1, 2, 3]], + ], + dim=1, + ) + hidden_states_mv_0245 = hidden_states_mv_0245.view( + hidden_states_mv_0245.shape[0], + hidden_states_mv_0245.shape[1], + -1, + hidden_states_mv_0245.shape[-1], + ) + + # combine row and col + hidden_states_mv = torch.stack( + [ + (hidden_states_mv_0123[:, 0] + hidden_states_mv_0245[:, 0]) / 2, + hidden_states_mv_0123[:, 1], + (hidden_states_mv_0123[:, 2] + hidden_states_mv_0245[:, 1]) / 2, + hidden_states_mv_0123[:, 3], + hidden_states_mv_0245[:, 2], + hidden_states_mv_0245[:, 3], + ], + dim=1, + ) + + hidden_states_mv = hidden_states_mv.view( + -1, hidden_states_mv.shape[-2], hidden_states_mv.shape[-1] + ) + hidden_states_mv = hidden_states_mv.to(query.dtype) + + # linear proj + hidden_states_mv = self.to_out_mv[0](hidden_states_mv) + # dropout + hidden_states_mv = self.to_out_mv[1](hidden_states_mv) + + if use_ref: + reference_hidden_states = ref_hidden_states[self.name] + + key_ref = self.to_k_ref(reference_hidden_states) + value_ref = self.to_v_ref(reference_hidden_states) + + query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + + hidden_states_ref = F.scaled_dot_product_attention( + query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False + ) + + hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states_ref = hidden_states_ref.to(query.dtype) + + # linear proj + hidden_states_ref = self.to_out_ref[0](hidden_states_ref) + # dropout + hidden_states_ref = self.to_out_ref[1](hidden_states_ref) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if use_mv: + hidden_states = hidden_states + hidden_states_mv * mv_scale + + if use_ref: + hidden_states = hidden_states + hidden_states_ref * ref_scale + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def set_num_views(self, num_views: int) -> None: + self.num_views = num_views diff --git a/step1x3d_texture/pipelines/ig2mv_sdxl_pipeline.py b/step1x3d_texture/pipelines/ig2mv_sdxl_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..299a95e9e4915934f16f5f7d3dd8018b7d43ff1a --- /dev/null +++ b/step1x3d_texture/pipelines/ig2mv_sdxl_pipeline.py @@ -0,0 +1,962 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.nn as nn +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.models import ( + AutoencoderKL, + ImageProjection, + T2IAdapter, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import ( + StableDiffusionXLPipelineOutput, +) +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( + StableDiffusionXLPipeline, + rescale_noise_cfg, + retrieve_timesteps, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import randn_tensor +from einops import rearrange +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ..loaders import CustomAdapterMixin +from ..models.attention_processor import ( + DecoupledMVRowSelfAttnProcessor2_0, + set_unet_2d_condition_attn_processor, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def retrieve_latents( + encoder_output: torch.Tensor, + generator: Optional[torch.Generator] = None, + sample_mode: str = "sample", +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, + add_watermarker=add_watermarker, + ) + + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + do_convert_rgb=True, + do_normalize=False, + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.prepare_latents + def prepare_image_latents( + self, + image, + timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator=None, + add_noise=True, + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if ( + hasattr(self.vae.config, "latents_mean") + and self.vae.config.latents_mean is not None + ): + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if ( + hasattr(self.vae.config, "latents_std") + and self.vae.config.latents_std is not None + ): + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents( + self.vae.encode(image[i : i + 1]), generator=generator[i] + ) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents( + self.vae.encode(image), generator=generator + ) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = ( + (init_latents - latents_mean) + * self.vae.config.scaling_factor + / latents_std + ) + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] == 0 + ): + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat( + [init_latents] * additional_image_per_prompt, dim=0 + ) + elif ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] != 0 + ): + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + num_empty_images=0, # for concat in batch like ImageDream + ): + assert hasattr( + self, "control_image_processor" + ), "control_image_processor is not initialized" + + image = self.control_image_processor.preprocess( + image, height=height, width=width + ).to(dtype=torch.float32) + + if num_empty_images > 0: + image = torch.cat( + [image, torch.zeros_like(image[:num_empty_images])], dim=0 + ) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt # always 1 for control image + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + # NEW + mv_scale: float = 1.0, + # Camera or geometry condition + control_image: Optional[PipelineImageInput] = None, + control_conditioning_scale: Optional[float] = 1.0, + control_conditioning_factor: float = 1.0, + # Image condition + reference_image: Optional[PipelineImageInput] = None, + reference_conditioning_scale: Optional[float] = 1.0, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat( + batch_size * num_images_per_prompt, 1 + ) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # Preprocess reference image + reference_image = self.image_processor.preprocess(reference_image) + reference_latents = self.prepare_image_latents( + reference_image, + timesteps[:1].repeat(batch_size * num_images_per_prompt), # no use + batch_size, + 1, + prompt_embeds.dtype, + device, + generator, + add_noise=False, + ) + + with torch.no_grad(): + ref_timesteps = torch.zeros_like(timesteps[0]) + ref_hidden_states = {} + + self.unet( + reference_latents, + ref_timesteps, + encoder_hidden_states=prompt_embeds[-1:], + added_cond_kwargs={ + "text_embeds": add_text_embeds[-1:], + "time_ids": add_time_ids[-1:], + }, + cross_attention_kwargs={ + "cache_hidden_states": ref_hidden_states, + "use_mv": False, + "use_ref": False, + }, + return_dict=False, + ) + ref_hidden_states = { + k: v.repeat_interleave(num_images_per_prompt, dim=0) + for k, v in ref_hidden_states.items() + } + if self.do_classifier_free_guidance: + ref_hidden_states = { + k: torch.cat([torch.zeros_like(v), v], dim=0) + for k, v in ref_hidden_states.items() + } + + cross_attention_kwargs = { + "mv_scale": mv_scale, + "ref_hidden_states": {k: v.clone() for k, v in ref_hidden_states.items()}, + "ref_scale": reference_conditioning_scale, + "num_views": num_images_per_prompt, + **(self.cross_attention_kwargs or {}), + } + + # Preprocess control image + control_image_feature = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=1, # NOTE: always 1 for control images + device=device, + dtype=latents.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + control_image_feature = control_image_feature.to( + device=device, dtype=latents.dtype + ) + + adapter_state = self.cond_encoder(control_image_feature) + for i, state in enumerate(adapter_state): + adapter_state[i] = state * control_conditioning_scale + + # 8. Denoising loop + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len( + list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) + ) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + if i < int(num_inference_steps * control_conditioning_factor): + down_intrablock_additional_residuals = [ + state.clone() for state in adapter_state + ] + else: + down_intrablock_additional_residuals = None + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + add_text_embeds = callback_outputs.pop( + "add_text_embeds", add_text_embeds + ) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop( + "negative_add_time_ids", negative_add_time_ids + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = ( + self.vae.dtype == torch.float16 and self.vae.config.force_upcast + ) + + if needs_upcasting: + self.upcast_vae() + latents = latents.to( + next(iter(self.vae.post_quant_conv.parameters())).dtype + ) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = ( + hasattr(self.vae.config, "latents_mean") + and self.vae.config.latents_mean is not None + ) + has_latents_std = ( + hasattr(self.vae.config, "latents_std") + and self.vae.config.latents_std is not None + ) + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, 4, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, 4, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = ( + latents * latents_std / self.vae.config.scaling_factor + + latents_mean + ) + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + ### NEW: adapters ### + def _init_custom_adapter( + self, + # Multi-view adapter + num_views: int = 1, + self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0, + # Condition encoder + cond_in_channels: int = 6, + # For training + copy_attn_weights: bool = True, + zero_init_module_keys: List[str] = [], + ): + # Condition encoder + self.cond_encoder = T2IAdapter( + in_channels=cond_in_channels, + channels=(320, 640, 1280, 1280), + num_res_blocks=2, + downscale_factor=16, + adapter_type="full_adapter_xl", + ) + + # set custom attn processor for multi-view attention and image cross-attention + self.unet: UNet2DConditionModel + set_unet_2d_condition_attn_processor( + self.unet, + set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor( + query_dim=hs, + inner_dim=hs, + num_views=num_views, + name=name, + use_mv=True, + use_ref=True, + ), + set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor( + query_dim=hs, + inner_dim=hs, + num_views=num_views, + name=name, + use_mv=False, + use_ref=False, + ), + ) + + # copy decoupled attention weights from original unet + if copy_attn_weights: + state_dict = self.unet.state_dict() + for key in state_dict.keys(): + if "_mv" in key: + compatible_key = key.replace("_mv", "").replace("processor.", "") + elif "_ref" in key: + compatible_key = key.replace("_ref", "").replace("processor.", "") + else: + compatible_key = key + + is_zero_init_key = any([k in key for k in zero_init_module_keys]) + if is_zero_init_key: + state_dict[key] = torch.zeros_like(state_dict[compatible_key]) + else: + state_dict[key] = state_dict[compatible_key].clone() + self.unet.load_state_dict(state_dict) + + def _load_custom_adapter(self, state_dict): + self.unet.load_state_dict(state_dict, strict=False) + self.cond_encoder.load_state_dict(state_dict, strict=False) + + def _save_custom_adapter( + self, + include_keys: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + ): + def include_fn(k): + is_included = False + + if include_keys is not None: + is_included = is_included or any([key in k for key in include_keys]) + if exclude_keys is not None: + is_included = is_included and not any( + [key in k for key in exclude_keys] + ) + + return is_included + + state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)} + state_dict.update(self.cond_encoder.state_dict()) + + return state_dict diff --git a/step1x3d_texture/pipelines/step1x_3d_texture_synthesis_pipeline.py b/step1x3d_texture/pipelines/step1x_3d_texture_synthesis_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..82891e432acb807c654d73fe33bebd636aff4948 --- /dev/null +++ b/step1x3d_texture/pipelines/step1x_3d_texture_synthesis_pipeline.py @@ -0,0 +1,405 @@ +import argparse + +import numpy as np +import torch +from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel +from PIL import Image +from torchvision import transforms +from tqdm import tqdm +from transformers import AutoModelForImageSegmentation + +from step1x3d_texture.models.attention_processor import ( + DecoupledMVRowColSelfAttnProcessor2_0, +) +from step1x3d_texture.pipelines.ig2mv_sdxl_pipeline import IG2MVSDXLPipeline +from step1x3d_texture.schedulers.scheduling_shift_snr import ShiftSNRScheduler +from step1x3d_texture.utils import ( + get_orthogonal_camera, + make_image_grid, + tensor_to_image, +) +from step1x3d_texture.utils.render import NVDiffRastContextWrapper, load_mesh, render +from step1x3d_texture.differentiable_renderer.mesh_render import MeshRender +import trimesh +import xatlas +import scipy.sparse +from scipy.sparse.linalg import spsolve +from step1x3d_texture.utils.shape_post_process import ( + FaceReducer, + FloaterRemover, + DegenerateFaceRemover, +) +from step1x3d_geometry.models.pipelines.pipeline_utils import smart_load_model + + +class Step1X3DTextureConfig: + def __init__(self): + # prepare pipeline params + self.base_model = "stabilityai/stable-diffusion-xl-base-1.0" + self.vae_model = "madebyollin/sdxl-vae-fp16-fix" + self.unet_model = None + self.lora_model = None + self.adapter_path = "stepfun-ai/Step1X-3D" + self.scheduler = None + self.num_views = 6 + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype = torch.float16 + self.lora_scale = None + + # run pipeline params + self.text = "high quality" + self.num_inference_steps = 50 + self.guidance_scale = 3.0 + self.seed = -1 + self.reference_conditioning_scale = 1.0 + self.negative_prompt = "watermark, ugly, deformed, noisy, blurry, low contrast" + self.azimuth_deg = [0, 45, 90, 180, 270, 315] + + # texture baker params + self.selected_camera_azims = [0, 90, 180, 270, 180, 180] + self.selected_camera_elevs = [0, 0, 0, 0, 90, -90] + self.selected_view_weights = [1, 0.1, 0.5, 0.1, 0.05, 0.05] + self.camera_distance = 1.8 + self.render_size = 2048 + self.texture_size = 2048 + self.bake_exp = 4 + self.merge_method = "fast" + + +class Step1X3DTexturePipeline: + def __init__(self, config): + self.config = config + self.mesh_render = MeshRender( + default_resolution=self.config.render_size, + texture_size=self.config.texture_size, + camera_distance=self.config.camera_distance, + ) + + self.ig2mv_pipe = self.prepare_ig2mv_pipeline( + base_model=self.config.base_model, + vae_model=self.config.vae_model, + unet_model=self.config.unet_model, + lora_model=self.config.lora_model, + adapter_path=self.config.adapter_path, + scheduler=self.config.scheduler, + num_views=self.config.num_views, + device=self.config.device, + dtype=self.config.dtype, + ) + + @classmethod + def from_pretrained(cls, model_path, subfolder): + config = Step1X3DTextureConfig() + local_model_path = smart_load_model(model_path, subfolder=subfolder) + print(f'Local model path: {local_model_path}') + config.adapter_path = local_model_path + return cls(config) + + def mesh_uv_wrap(self, mesh): + if isinstance(mesh, trimesh.Scene): + mesh = mesh.to_geometry() + vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces) + mesh.vertices = mesh.vertices[vmapping] + mesh.faces = indices + mesh.visual.uv = uvs + + return mesh + + def prepare_ig2mv_pipeline( + self, + base_model, + vae_model, + unet_model, + lora_model, + adapter_path, + scheduler, + num_views, + device, + dtype, + ): + # Load vae and unet if provided + pipe_kwargs = {} + if vae_model is not None: + pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model) + if unet_model is not None: + pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model) + + print('VAE Loaded!') + # Prepare pipeline + pipe = IG2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs) + + print('Base model Loaded!') + # Load scheduler if provided + scheduler_class = None + if scheduler == "ddpm": + scheduler_class = DDPMScheduler + elif scheduler == "lcm": + scheduler_class = LCMScheduler + + pipe.scheduler = ShiftSNRScheduler.from_scheduler( + pipe.scheduler, + shift_mode="interpolated", + shift_scale=8.0, + scheduler_class=scheduler_class, + ) + print('Scheduler Loaded!') + pipe.init_custom_adapter( + num_views=num_views, + self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0, + ) + print(f'Load adapter from {adapter_path}/step1x-3d-ig2v.safetensors') + pipe.load_custom_adapter(adapter_path, "step1x-3d-ig2v.safetensors") + print(f'Load adapter successed!') + pipe.to(device=device, dtype=dtype) + pipe.cond_encoder.to(device=device, dtype=dtype) + + # load lora if provided + if lora_model is not None: + model_, name_ = lora_model.rsplit("/", 1) + pipe.load_lora_weights(model_, weight_name=name_) + + return pipe + + def remove_bg(self, image, net, transform, device): + image_size = image.size + input_images = transform(image).unsqueeze(0).to(device) + with torch.no_grad(): + preds = net(input_images)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + image.putalpha(mask) + return image + + def preprocess_image(self, image, height, width): + image = np.array(image) + alpha = image[..., 3] > 0 + H, W = alpha.shape + # get the bounding box of alpha + y, x = np.where(alpha) + y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H) + x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W) + image_center = image[y0:y1, x0:x1] + # resize the longer side to H * 0.9 + H, W, _ = image_center.shape + if H > W: + W = int(W * (height * 0.9) / H) + H = int(height * 0.9) + else: + H = int(H * (width * 0.9) / W) + W = int(width * 0.9) + image_center = np.array(Image.fromarray(image_center).resize((W, H))) + # pad to H, W + start_h = (height - H) // 2 + start_w = (width - W) // 2 + image = np.zeros((height, width, 4), dtype=np.uint8) + image[start_h : start_h + H, start_w : start_w + W] = image_center + image = image.astype(np.float32) / 255.0 + image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5 + image = (image * 255).clip(0, 255).astype(np.uint8) + image = Image.fromarray(image) + + return image + + def run_ig2mv_pipeline( + self, + pipe, + mesh, + num_views, + text, + image, + height, + width, + num_inference_steps, + guidance_scale, + seed, + remove_bg_fn=None, + reference_conditioning_scale=1.0, + negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", + lora_scale=1.0, + device="cuda", + ): + # Prepare cameras + cameras = get_orthogonal_camera( + elevation_deg=[0, 0, 0, 0, 89.99, -89.99], + distance=[1.8] * num_views, + left=-0.55, + right=0.55, + bottom=-0.55, + top=0.55, + azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]], + device=device, + ) + ctx = NVDiffRastContextWrapper(device=device, context_type="cuda") + + mesh, mesh_bp = load_mesh(mesh, rescale=True, device=device) + render_out = render( + ctx, + mesh, + cameras, + height=height, + width=width, + render_attr=False, + normal_background=0.0, + ) + pos_images = tensor_to_image((render_out.pos + 0.5).clamp(0, 1), batched=True) + normal_images = tensor_to_image( + (render_out.normal / 2 + 0.5).clamp(0, 1), batched=True + ) + control_images = ( + torch.cat( + [ + (render_out.pos + 0.5).clamp(0, 1), + (render_out.normal / 2 + 0.5).clamp(0, 1), + ], + dim=-1, + ) + .permute(0, 3, 1, 2) + .to(device) + ) + + # Prepare image + reference_image = Image.open(image) if isinstance(image, str) else image + if len(reference_image.split()) == 1: + reference_image = reference_image.convert("RGBA") + if remove_bg_fn is not None and reference_image.mode == "RGB": + reference_image = remove_bg_fn(reference_image) + reference_image = self.preprocess_image(reference_image, height, width) + elif reference_image.mode == "RGBA": + reference_image = self.preprocess_image(reference_image, height, width) + + pipe_kwargs = {} + if seed != -1 and isinstance(seed, int): + pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed) + + images = pipe( + text, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_views, + control_image=control_images, + control_conditioning_scale=1.0, + reference_image=reference_image, + reference_conditioning_scale=reference_conditioning_scale, + negative_prompt=negative_prompt, + cross_attention_kwargs={"scale": lora_scale}, + mesh=mesh_bp, + **pipe_kwargs, + ).images + + return images, pos_images, normal_images, reference_image, mesh, mesh_bp + + def bake_from_multiview( + self, + render, + views, + camera_elevs, + camera_azims, + view_weights, + method="graphcut", + bake_exp=4, + ): + project_textures, project_weighted_cos_maps = [], [] + project_boundary_maps = [] + for view, camera_elev, camera_azim, weight in zip( + views, camera_elevs, camera_azims, view_weights + ): + project_texture, project_cos_map, project_boundary_map = ( + render.back_project(view, camera_elev, camera_azim) + ) + project_cos_map = weight * (project_cos_map**bake_exp) + project_textures.append(project_texture) + project_weighted_cos_maps.append(project_cos_map) + project_boundary_maps.append(project_boundary_map) + + if method == "fast": + texture, ori_trust_map = render.fast_bake_texture( + project_textures, project_weighted_cos_maps + ) + elif method == "poisson": + texture = poisson_blend( + project_textures, project_weighted_cos_maps, project_boundary_maps + ) + else: + raise f"no method {method}" + return texture, ori_trust_map > 1e-8 + + def texture_inpaint(self, render, texture, mask): + texture_np = render.uv_inpaint(texture, mask) + texture = torch.tensor(texture_np / 255).float().to(texture.device) + + return texture + + @torch.no_grad() + def __call__(self, image, mesh, remove_bg=True): + if remove_bg: + birefnet = AutoModelForImageSegmentation.from_pretrained( + "ZhengPeng7/BiRefNet", trust_remote_code=True + ) + birefnet.to(self.config.device) + transform_image = transforms.Compose( + [ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + remove_bg_fn = lambda x: self.remove_bg( + x, birefnet, transform_image, args.device + ) + else: + remove_bg_fn = None + + if isinstance(mesh, trimesh.Scene): + mesh = mesh.to_geometry() + + # multi-view generation pipeline + images, pos_images, normal_images, reference_image, textured_mesh, mesh_bp = ( + self.run_ig2mv_pipeline( + self.ig2mv_pipe, + mesh=mesh, + num_views=self.config.num_views, + text=self.config.text, + image=image, + height=768, + width=768, + num_inference_steps=self.config.num_inference_steps, + guidance_scale=self.config.guidance_scale, + seed=self.config.seed, + lora_scale=self.config.lora_scale, + reference_conditioning_scale=self.config.reference_conditioning_scale, + negative_prompt=self.config.negative_prompt, + device=self.config.device, + remove_bg_fn=remove_bg_fn, + ) + ) + + for i in range(len(images)): + images[i] = images[i].resize( + (self.config.render_size, self.config.render_size), + Image.Resampling.LANCZOS, + ) + + mesh = self.mesh_uv_wrap(mesh_bp) + self.mesh_render.load_mesh(mesh, auto_center=False, scale_factor=1.0) + + # texture baker + texture, mask = self.bake_from_multiview( + self.mesh_render, + images, + self.config.selected_camera_elevs, + self.config.selected_camera_azims, + self.config.selected_view_weights, + method="fast", + ) + mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8) + + # texture inpaint + texture = self.texture_inpaint(self.mesh_render, texture, mask_np) + + self.mesh_render.set_texture(texture) + textured_mesh = self.mesh_render.save_mesh() + + return textured_mesh diff --git a/step1x3d_texture/renderer/__init__.py b/step1x3d_texture/renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/step1x3d_texture/renderer/geometry.py b/step1x3d_texture/renderer/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..6f70671d01d00663fdc0c3a7c95bb8ecfd84ec4d --- /dev/null +++ b/step1x3d_texture/renderer/geometry.py @@ -0,0 +1,151 @@ +import torch +import pytorch3d +import torch.nn.functional as F + +from pytorch3d.ops import interpolate_face_attributes + +from pytorch3d.renderer import ( + look_at_view_transform, + FoVPerspectiveCameras, + AmbientLights, + PointLights, + DirectionalLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + SoftSilhouetteShader, + HardPhongShader, + TexturesVertex, + TexturesUV, + Materials, +) +from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend +from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties +from pytorch3d.renderer.mesh.shader import ShaderBase + + +def get_cos_angle(points, normals, camera_position): + """ + calculate cosine similarity between view->surface and surface normal. + """ + + if points.shape != normals.shape: + msg = "Expected points and normals to have the same shape: got %r, %r" + raise ValueError(msg % (points.shape, normals.shape)) + + # Ensure all inputs have same batch dimension as points + matched_tensors = convert_to_tensors_and_broadcast( + points, camera_position, device=points.device + ) + _, camera_position = matched_tensors + + # Reshape direction and color so they have all the arbitrary intermediate + # dimensions as points. Assume first dim = batch dim and last dim = 3. + points_dims = points.shape[1:-1] + expand_dims = (-1,) + (1,) * len(points_dims) + + if camera_position.shape != normals.shape: + camera_position = camera_position.view(expand_dims + (3,)) + + normals = F.normalize(normals, p=2, dim=-1, eps=1e-6) + + # Calculate the cosine value. + view_direction = camera_position - points + view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6) + cos_angle = torch.sum(view_direction * normals, dim=-1, keepdim=True) + cos_angle = cos_angle.clamp(0, 1) + + # Cosine of the angle between the reflected light ray and the viewer + return cos_angle + + +def _geometry_shading_with_pixels( + meshes, fragments, lights, cameras, materials, texels +): + """ + Render pixel space vertex position, normal(world), depth, and cos angle + + Args: + meshes: Batch of meshes + fragments: Fragments named tuple with the outputs of rasterization + lights: Lights class containing a batch of lights + cameras: Cameras class containing a batch of cameras + materials: Materials class containing a batch of material properties + texels: texture per pixel of shape (N, H, W, K, 3) + + Returns: + colors: (N, H, W, K, 3) + pixel_coords: (N, H, W, K, 3), camera coordinates of each intersection. + """ + verts = meshes.verts_packed() # (V, 3) + faces = meshes.faces_packed() # (F, 3) + vertex_normals = meshes.verts_normals_packed() # (V, 3) + faces_verts = verts[faces] + faces_normals = vertex_normals[faces] + pixel_coords_in_camera = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts + ) + pixel_normals = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_normals + ) + + cos_angles = get_cos_angle( + pixel_coords_in_camera, pixel_normals, cameras.get_camera_center() + ) + + return pixel_coords_in_camera, pixel_normals, fragments.zbuf[..., None], cos_angles + + +class HardGeometryShader(ShaderBase): + """ + renders common geometric informations. + + + """ + + def forward(self, fragments, meshes, **kwargs): + cameras = super()._get_cameras(**kwargs) + texels = self.texel_from_uv(fragments, meshes) + + lights = kwargs.get("lights", self.lights) + materials = kwargs.get("materials", self.materials) + blend_params = kwargs.get("blend_params", self.blend_params) + verts, normals, depths, cos_angles = _geometry_shading_with_pixels( + meshes=meshes, + fragments=fragments, + texels=texels, + lights=lights, + cameras=cameras, + materials=materials, + ) + texels = meshes.sample_textures(fragments) + verts = hard_rgb_blend(verts, fragments, blend_params) + normals = hard_rgb_blend(normals, fragments, blend_params) + depths = hard_rgb_blend(depths, fragments, blend_params) + cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params) + from IPython import embed + + embed() + texels = hard_rgb_blend(texels, fragments, blend_params) + return verts, normals, depths, cos_angles, texels, fragments + + def texel_from_uv(self, fragments, meshes): + texture_tmp = meshes.textures + maps_tmp = texture_tmp.maps_padded() + uv_color = [[[1, 0], [1, 1]], [[0, 0], [0, 1]]] + uv_color = ( + torch.FloatTensor(uv_color).to(maps_tmp[0].device).type(maps_tmp[0].dtype) + ) + uv_texture = TexturesUV( + [uv_color.clone() for t in maps_tmp], + texture_tmp.faces_uvs_padded(), + texture_tmp.verts_uvs_padded(), + sampling_mode="bilinear", + ) + meshes.textures = uv_texture + texels = meshes.sample_textures(fragments) + meshes.textures = texture_tmp + texels = torch.cat((texels, texels[..., -1:] * 0), dim=-1) + return texels diff --git a/step1x3d_texture/renderer/project.py b/step1x3d_texture/renderer/project.py new file mode 100644 index 0000000000000000000000000000000000000000..fd4432288e2376bccaf27692fbd4ba4ae04cfc4c --- /dev/null +++ b/step1x3d_texture/renderer/project.py @@ -0,0 +1,875 @@ +import torch +import pytorch3d + + +from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj, IO + +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + look_at_view_transform, + FoVPerspectiveCameras, + FoVOrthographicCameras, + AmbientLights, + PointLights, + DirectionalLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + TexturesUV, +) + +from .geometry import HardGeometryShader +from .shader import HardNChannelFlatShader +from .voronoi import voronoi_solve +import torch.nn.functional as F +import open3d as o3d +import pdb +import kaolin as kal +import numpy as np + + +import torch +from pytorch3d.renderer.cameras import FoVOrthographicCameras +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from pytorch3d.common.datatypes import Device +import math +import torch.nn.functional as F +from trimesh import Trimesh +from pytorch3d.structures import Meshes +import os + +LIST_TYPE = Union[list, np.ndarray, torch.Tensor] + +_R = torch.eye(3)[None] # (1, 3, 3) +_T = torch.zeros(1, 3) # (1, 3) +_BatchFloatType = Union[float, Sequence[float], torch.Tensor] + + +class CustomOrthographicCameras(FoVOrthographicCameras): + def compute_projection_matrix( + self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz + ) -> torch.Tensor: + """ + 自定义正交投影矩阵计算,继承并修改深度通道参数 + 参数维度说明: + - znear/zfar: (N,) + - max_x/min_x: (N,) + - max_y/min_y: (N,) + - scale_xyz: (N, 3) + """ + K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device) + + ones = torch.ones((self._N), dtype=torch.float32, device=self.device) + # NOTE: OpenGL flips handedness of coordinate system between camera + # space and NDC space so z sign is -ve. In PyTorch3D we maintain a + # right handed coordinate system throughout. + z_sign = +1.0 + + K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0] + K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1] + K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x) + K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y) + K[:, 3, 3] = ones + + # NOTE: This maps the z coordinate to the range [0, 1] and replaces the + # the OpenGL z normalization to [-1, 1] + K[:, 2, 2] = -2 * (1.0 / (zfar - znear)) * scale_xyz[:, 2] + K[:, 2, 3] = -(znear + zfar) / (zfar - znear) + + return K + + def __init__( + self, + znear: _BatchFloatType = 1.0, + zfar: _BatchFloatType = 100.0, + max_y: _BatchFloatType = 1.0, + min_y: _BatchFloatType = -1.0, + max_x: _BatchFloatType = 1.0, + min_x: _BatchFloatType = -1.0, + scale_xyz=((1.0, 1.0, 1.0),), # (N, 3) + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", + ): + # 继承父类初始化逻辑 + super().__init__( + znear=znear, + zfar=zfar, + max_y=max_y, + min_y=min_y, + max_x=max_x, + min_x=min_x, + scale_xyz=scale_xyz, + R=R, + T=T, + K=K, + device=device, + ) + + +def erode_torch_batch(binary_img_batch, kernel_size): + pad = (kernel_size - 1) // 2 + bin_img = F.pad( + binary_img_batch.unsqueeze(1), pad=[pad, pad, pad, pad], mode="reflect" + ) + out = -F.max_pool2d(-bin_img, kernel_size=kernel_size, stride=1, padding=0) + out = out.squeeze(1) + return out + + +def dilate_torch_batch(binary_img_batch, kernel_size): + pad = (kernel_size - 1) // 2 + bin_img = F.pad(binary_img_batch, pad=[pad, pad, pad, pad], mode="reflect") + out = F.max_pool2d(bin_img, kernel_size=kernel_size, stride=1, padding=0) + out = out.squeeze() + return out + + +# Pytorch3D based renderering functions, managed in a class +# Render size is recommended to be the same as your latent view size +# DO NOT USE "bilinear" sampling when you are handling latents. +# Stable Diffusion has 4 latent channels so use channels=4 + + +class UVProjection: + def __init__( + self, + texture_size=96, + render_size=64, + sampling_mode="nearest", + channels=3, + device=None, + ): + self.channels = channels + self.device = device or torch.device("cpu") + self.lights = AmbientLights( + ambient_color=((1.0,) * channels,), device=self.device + ) + self.target_size = (texture_size, texture_size) + self.render_size = render_size + self.sampling_mode = sampling_mode + + # Load obj mesh, rescale the mesh to fit into the bounding box + def load_mesh(self, mesh, scale_factor=2.0, auto_center=True, autouv=False): + if isinstance(mesh, Trimesh): + vertices = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device) + faces = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device) + mesh = Meshes(verts=[vertices], faces=[faces]) + verts = mesh.verts_packed() + mesh = mesh.update_padded(verts[None, :, :]) + elif isinstance(mesh, str) and os.path.isfile(mesh): + mesh = load_objs_as_meshes([mesh_path], device=self.device) + if auto_center: + verts = mesh.verts_packed() + max_bb = (verts - 0).max(0)[0] + min_bb = (verts - 0).min(0)[0] + scale = (max_bb - min_bb).max() / 2 + center = (max_bb + min_bb) / 2 + mesh.offset_verts_(-center) + mesh.scale_verts_((scale_factor / float(scale))) + else: + mesh.scale_verts_((scale_factor)) + + if autouv or (mesh.textures is None): + mesh = self.uv_unwrap(mesh) + self.mesh = mesh + + def load_glb_mesh( + self, mesh_path, trimesh, scale_factor=1.0, auto_center=True, autouv=False + ): + from pytorch3d.io.experimental_gltf_io import MeshGlbFormat + + io = IO() + io.register_meshes_format(MeshGlbFormat()) + with open(mesh_path, "rb") as f: + mesh = io.load_mesh(f, include_textures=True, device=self.device) + if auto_center: + verts = mesh.verts_packed() + + max_bb = (verts - 0).max(0)[0] + min_bb = (verts - 0).min(0)[0] + scale = (max_bb - min_bb).max() / 2 + center = (max_bb + min_bb) / 2 + mesh.offset_verts_(-center) + mesh.scale_verts_((scale_factor / float(scale))) + verts = mesh.verts_packed() + # T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=verts.device, dtype=verts.dtype) + # T = torch.tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]], device=verts.device, dtype=verts.dtype) + # verts = verts @ T + mesh = mesh.update_padded(verts[None, :, :]) + else: + mesh.scale_verts_((scale_factor)) + if autouv or (mesh.textures is None): + mesh = self.uv_unwrap(mesh) + self.mesh = mesh + + # Save obj mesh + def save_mesh(self, mesh_path, texture): + save_obj( + mesh_path, + self.mesh.verts_list()[0], + self.mesh.faces_list()[0], + verts_uvs=self.mesh.textures.verts_uvs_list()[0], + faces_uvs=self.mesh.textures.faces_uvs_list()[0], + texture_map=texture, + ) + + # Code referred to TEXTure code (https://github.com/TEXTurePaper/TEXTurePaper.git) + def uv_unwrap(self, mesh): + verts_list = mesh.verts_list()[0] + faces_list = mesh.faces_list()[0] + + import xatlas + import numpy as np + + v_np = verts_list.cpu().numpy() + f_np = faces_list.int().cpu().numpy() + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + chart_options.max_iterations = 4 + atlas.generate(chart_options=chart_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + vt = ( + torch.from_numpy(vt_np.astype(np.float32)) + .type(verts_list.dtype) + .to(mesh.device) + ) + ft = ( + torch.from_numpy(ft_np.astype(np.int64)) + .type(faces_list.dtype) + .to(mesh.device) + ) + + new_map = torch.zeros(self.target_size + (self.channels,), device=mesh.device) + new_tex = TexturesUV([new_map], [ft], [vt], sampling_mode=self.sampling_mode) + + mesh.textures = new_tex + return mesh + + """ + A functions that disconnect faces in the mesh according to + its UV seams. The number of vertices are made equal to the + number of unique vertices its UV layout, while the faces list + is intact. + """ + + def disconnect_faces(self): + mesh = self.mesh + verts_list = mesh.verts_list() + faces_list = mesh.faces_list() + verts_uvs_list = mesh.textures.verts_uvs_list() + faces_uvs_list = mesh.textures.faces_uvs_list() + packed_list = [v[f] for v, f in zip(verts_list, faces_list)] + verts_disconnect_list = [ + torch.zeros( + (verts_uvs_list[i].shape[0], 3), + dtype=verts_list[0].dtype, + device=verts_list[0].device, + ) + for i in range(len(verts_list)) + ] + for i in range(len(verts_list)): + verts_disconnect_list[i][faces_uvs_list] = packed_list[i] + assert not mesh.has_verts_normals(), "Not implemented for vertex normals" + self.mesh_d = Meshes(verts_disconnect_list, faces_uvs_list, mesh.textures) + return self.mesh_d + + """ + A function that construct a temp mesh for back-projection. + Take a disconnected mesh and a rasterizer, the function calculates + the projected faces as the UV, as use its original UV with pseudo + z value as world space geometry. + """ + + def construct_uv_mesh(self): + mesh = self.mesh_d + verts_list = mesh.verts_list() + verts_uvs_list = mesh.textures.verts_uvs_list() + # faces_list = [torch.flip(faces, [-1]) for faces in mesh.faces_list()] + new_verts_list = [] + for i, (verts, verts_uv) in enumerate(zip(verts_list, verts_uvs_list)): + verts = verts.clone() + verts_uv = verts_uv.clone() + verts[..., 0:2] = verts_uv[..., :] + verts = (verts - 0.5) * 2 + verts[..., 2] *= 1 + new_verts_list.append(verts) + textures_uv = mesh.textures.clone() + self.mesh_uv = Meshes(new_verts_list, mesh.faces_list(), textures_uv) + return self.mesh_uv + + # Set texture for the current mesh. + def set_texture_map(self, texture): + new_map = texture.permute(1, 2, 0) + new_map = new_map.to(self.device) + new_tex = TexturesUV( + [new_map], + self.mesh.textures.faces_uvs_padded(), + self.mesh.textures.verts_uvs_padded(), + sampling_mode=self.sampling_mode, + ) + self.mesh.textures = new_tex + + # Set the initial normal noise texture + # No generator here for replication of the experiment result. Add one as you wish + def set_noise_texture(self, channels=None): + if not channels: + channels = self.channels + noise_texture = torch.normal( + 0, 1, (channels,) + self.target_size, device=self.device + ) + self.set_texture_map(noise_texture) + return noise_texture + + # Set the cameras given the camera poses and centers + def set_cameras(self, camera_poses, centers=None, camera_distance=2.7, scale=None): + elev = torch.FloatTensor([pose[0] for pose in camera_poses]) + azim = torch.FloatTensor([pose[1] for pose in camera_poses]) + print("camera_distance:{}".format(camera_distance)) + R, T = look_at_view_transform( + dist=camera_distance, elev=elev, azim=azim, at=centers or ((0, 0, 0),) + ) + # flip_mat = torch.from_numpy(np.diag([-1.0, 1.0, -1.0]) ).type(torch.FloatTensor).to(R.device) + # R = R@flip_mat + # R = R.permute(0, 2, 1) + # T = T*torch.from_numpy(np.array([-1.0, 1.0, -1.0])).type(torch.FloatTensor).to(R.device) + # print("v R size:{}, v T size:{}".format(R.size(), T.size())) + # c2w = self.get_c2w(elev, [camera_distance]*len(elev), azim) + # w2c = torch.linalg.inv(c2w) + # R, T= w2c[:, :3, :3], w2c[:, :3, 3] + print("R size:{}, T size:{}".format(R.size(), T.size())) + # self.cameras = CustomOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),), znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55) + self.cameras = FoVOrthographicCameras( + device=self.device, R=R, T=T, scale_xyz=scale or ((1, 1, 1),) + ) + + # Set all necessary internal data for rendering and texture baking + # Can be used to refresh after changing camera positions + def set_cameras_and_render_settings( + self, + camera_poses, + centers=None, + camera_distance=2.7, + render_size=None, + scale=None, + ): + self.set_cameras(camera_poses, centers, camera_distance, scale=scale) + if render_size is None: + render_size = self.render_size + if not hasattr(self, "renderer"): + self.setup_renderer(size=render_size) + if not hasattr(self, "mesh_d"): + self.disconnect_faces() + if not hasattr(self, "mesh_uv"): + self.construct_uv_mesh() + self.calculate_tex_gradient() + self.calculate_visible_triangle_mask() + _, _, _, cos_maps, _, _ = self.render_geometry() + self.calculate_cos_angle_weights(cos_maps) + + # Setup renderers for rendering + # max faces per bin set to 30000 to avoid overflow in many test cases. + # You can use default value to let pytorch3d handle that for you. + def setup_renderer( + self, + size=64, + blur=0.0, + face_per_pix=1, + perspective_correct=False, + channels=None, + ): + if not channels: + channels = self.channels + + self.raster_settings = RasterizationSettings( + image_size=size, + blur_radius=blur, + faces_per_pixel=face_per_pix, + perspective_correct=perspective_correct, + cull_backfaces=True, + max_faces_per_bin=30000, + ) + + self.renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=self.cameras, + raster_settings=self.raster_settings, + ), + shader=HardNChannelFlatShader( + device=self.device, + cameras=self.cameras, + lights=self.lights, + channels=channels, + # materials=materials + ), + ) + + # Bake screen-space cosine weights to UV space + # May be able to reimplement using the generic "bake_texture" function, but it works so leave it here for now + @torch.enable_grad() + def calculate_cos_angle_weights(self, cos_angles, fill=True, channels=None): + if not channels: + channels = self.channels + cos_maps = [] + tmp_mesh = self.mesh.clone() + for i in range(len(self.cameras)): + + zero_map = torch.zeros( + self.target_size + (channels,), device=self.device, requires_grad=True + ) + optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0) + optimizer.zero_grad() + zero_tex = TexturesUV( + [zero_map], + self.mesh.textures.faces_uvs_padded(), + self.mesh.textures.verts_uvs_padded(), + sampling_mode=self.sampling_mode, + ) + tmp_mesh.textures = zero_tex + + images_predicted = self.renderer( + tmp_mesh, cameras=self.cameras[i], lights=self.lights + ) + + loss = torch.sum((cos_angles[i, :, :, 0:1] ** 1 - images_predicted) ** 2) + loss.backward() + optimizer.step() + + if fill: + zero_map = zero_map.detach() / (self.gradient_maps[i] + 1e-8) + zero_map = voronoi_solve( + zero_map, self.gradient_maps[i][..., 0], self.device + ) + else: + zero_map = zero_map.detach() / (self.gradient_maps[i] + 1e-8) + cos_maps.append(zero_map) + self.cos_maps = cos_maps + + # Get geometric info from fragment shader + # Can be used for generating conditioning image and cosine weights + # Returns some information you may not need, remember to release them for memory saving + @torch.no_grad() + def render_geometry(self, image_size=None): + if image_size: + size = self.renderer.rasterizer.raster_settings.image_size + self.renderer.rasterizer.raster_settings.image_size = image_size + shader = self.renderer.shader + self.renderer.shader = HardGeometryShader( + device=self.device, cameras=self.cameras[0], lights=self.lights + ) + tmp_mesh = self.mesh.clone() + + verts, normals, depths, cos_angles, texels, fragments = self.renderer( + tmp_mesh.extend(len(self.cameras)), cameras=self.cameras, lights=self.lights + ) + self.renderer.shader = shader + + if image_size: + self.renderer.rasterizer.raster_settings.image_size = size + + return verts, normals, depths, cos_angles, texels, fragments + + # Project world normal to view space and normalize + @torch.no_grad() + def decode_view_normal(self, normals): + w2v_mat = self.cameras.get_full_projection_transform() + normals_view = torch.clone(normals)[:, :, :, 0:3] + normals_view = normals_view.reshape(normals_view.shape[0], -1, 3) + normals_view = w2v_mat.transform_normals(normals_view) + normals_view = normals_view.reshape(normals.shape[0:3] + (3,)) + normals_view[:, :, :, 2] *= -1 + normals = (normals_view[..., 0:3] + 1) * normals[ + ..., 3: + ] / 2 + torch.FloatTensor(((((0.5, 0.5, 1))))).to(self.device) * ( + 1 - normals[..., 3:] + ) + # normals = torch.cat([normal for normal in normals], dim=1) + normals = normals.clamp(0, 1) + return normals + + # Normalize absolute depth to inverse depth + @torch.no_grad() + def decode_normalized_depth(self, depths, batched_norm=False): + view_z, mask = depths.unbind(-1) + view_z = view_z * mask + 100 * (1 - mask) + inv_z = 1 / view_z + inv_z_min = inv_z * mask + 100 * (1 - mask) + if not batched_norm: + max_ = torch.max(inv_z, 1, keepdim=True) + max_ = torch.max(max_[0], 2, keepdim=True)[0] + + min_ = torch.min(inv_z_min, 1, keepdim=True) + min_ = torch.min(min_[0], 2, keepdim=True)[0] + else: + max_ = torch.max(inv_z) + min_ = torch.min(inv_z_min) + inv_z = (inv_z - min_) / (max_ - min_) + inv_z = inv_z.clamp(0, 1) + inv_z = inv_z[..., None].repeat(1, 1, 1, 3) + + return inv_z + + # Multiple screen pixels could pass gradient to a same texel + # We can precalculate this gradient strength and use it to normalize gradients when we bake textures + @torch.enable_grad() + def calculate_tex_gradient(self, channels=None): + if not channels: + channels = self.channels + tmp_mesh = self.mesh.clone() + gradient_maps = [] + for i in range(len(self.cameras)): + zero_map = torch.zeros( + self.target_size + (channels,), device=self.device, requires_grad=True + ) + optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0) + optimizer.zero_grad() + zero_tex = TexturesUV( + [zero_map], + self.mesh.textures.faces_uvs_padded(), + self.mesh.textures.verts_uvs_padded(), + sampling_mode=self.sampling_mode, + ) + tmp_mesh.textures = zero_tex + images_predicted = self.renderer( + tmp_mesh, cameras=self.cameras[i], lights=self.lights + ) + loss = torch.sum((1 - images_predicted) ** 2) + loss.backward() + optimizer.step() + + gradient_maps.append(zero_map.detach()) + + self.gradient_maps = gradient_maps + + # Get the UV space masks of triangles visible in each view + # First get face ids from each view, then filter pixels on UV space to generate masks + + @torch.no_grad() + def get_c2w( + self, + elevation_deg: LIST_TYPE, + distance: LIST_TYPE, + azimuth_deg: Optional[LIST_TYPE], + num_views: Optional[int] = 1, + device: Optional[str] = None, + ) -> torch.FloatTensor: + if azimuth_deg is None: + assert ( + num_views is not None + ), "num_views must be provided if azimuth_deg is None." + azimuth_deg = torch.linspace( + 0, 360, num_views + 1, dtype=torch.float32, device=device + )[:-1] + else: + num_views = len(azimuth_deg) + + def list_to_pt( + x: LIST_TYPE, + dtype: Optional[torch.dtype] = None, + device: Optional[str] = None, + ) -> torch.Tensor: + if isinstance(x, list) or isinstance(x, np.ndarray): + return torch.tensor(x, dtype=dtype, device=device) + return x.to(dtype=dtype) + + azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device) + elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device) + camera_distances = list_to_pt(distance, dtype=torch.float32, device=device) + elevation = elevation_deg * math.pi / 180 + azimuth = azimuth_deg * math.pi / 180 + camera_positions = torch.stack( + [ + camera_distances * torch.cos(elevation) * torch.cos(azimuth), + camera_distances * torch.cos(elevation) * torch.sin(azimuth), + camera_distances * torch.sin(elevation), + ], + dim=-1, + ) + center = torch.zeros_like(camera_positions) + up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[ + None, : + ].repeat(num_views, 1) + lookat = F.normalize(center - camera_positions, dim=-1) + right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1) + up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1) + c2w3x4 = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], + dim=-1, + ) + c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1) + c2w[:, 3, 3] = 1.0 + return c2w + + @torch.no_grad() + def calculate_visible_triangle_mask(self, channels=None, image_size=(512, 512)): + if not channels: + channels = self.channels + + pix2face_list = [] + for i in range(len(self.cameras)): + self.renderer.rasterizer.raster_settings.image_size = image_size + pix2face = self.renderer.rasterizer( + self.mesh_d, cameras=self.cameras[i] + ).pix_to_face + self.renderer.rasterizer.raster_settings.image_size = self.render_size + pix2face_list.append(pix2face) + + if not hasattr(self, "mesh_uv"): + self.construct_uv_mesh() + + raster_settings = RasterizationSettings( + image_size=self.target_size, + blur_radius=0, + faces_per_pixel=1, + perspective_correct=False, + cull_backfaces=False, + max_faces_per_bin=30000, + ) + + R, T = look_at_view_transform(dist=2, elev=0, azim=0) + # flip_mat = torch.from_numpy(np.diag([-1.0, 1.0, -1.0]) ).type(torch.FloatTensor).to(R.device) + # R = R@flip_mat + # T = T*torch.tensor(np.array([-1.0, 1.0, -1.0])).type(torch.FloatTensor).to(R.device) + # c2w = self.get_c2w([0], [1.8], [0]) + # w2c = torch.linalg.inv(c2w)[:, :3,:] + # R, T= w2c[:, :3,:3], w2c[:, :3, 3] + # print("R size:{}, T size:{}".format(R.size(), T.size())) + cameras = FoVOrthographicCameras(device=self.device, R=R, T=T) + # cameras = CustomOrthographicCameras(device=self.device, R=R, T=T) + + # cameras = CustomOrthographicCameras(device=self.device, R=R, T=T, znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55) + + rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) + uv_pix2face = rasterizer(self.mesh_uv).pix_to_face + + visible_triangles = [] + for i in range(len(pix2face_list)): + valid_faceid = torch.unique(pix2face_list[i]) + valid_faceid = valid_faceid[1:] if valid_faceid[0] == -1 else valid_faceid + mask = torch.isin(uv_pix2face[0], valid_faceid, assume_unique=False) + # uv_pix2face[0][~mask] = -1 + triangle_mask = torch.ones(self.target_size + (1,), device=self.device) + triangle_mask[~mask] = 0 + + triangle_mask[:, 1:][triangle_mask[:, :-1] > 0] = 1 + triangle_mask[:, :-1][triangle_mask[:, 1:] > 0] = 1 + triangle_mask[1:, :][triangle_mask[:-1, :] > 0] = 1 + triangle_mask[:-1, :][triangle_mask[1:, :] > 0] = 1 + visible_triangles.append(triangle_mask) + + self.visible_triangles = visible_triangles + + # Render the current mesh and texture from current cameras + def render_textured_views(self): + meshes = self.mesh.extend(len(self.cameras)) + images_predicted = self.renderer( + meshes, cameras=self.cameras, lights=self.lights + ) + + return [image.permute(2, 0, 1) for image in images_predicted] + + @torch.no_grad() + def get_point_validation_by_o3d( + self, points, eye_position, hidden_point_removal_radius=200 + ): + point_visibility = torch.zeros((points.shape[0]), device=points.device).bool() + + pcd = o3d.geometry.PointCloud( + points=o3d.utility.Vector3dVector(points.cpu().numpy()) + ) + camera_pose = ( + eye_position.get_camera_center().squeeze().cpu().numpy().astype(np.float64) + ) + # o3d_camera = [0, 0, diameter] + diameter = np.linalg.norm( + np.asarray(pcd.get_max_bound()) - np.asarray(pcd.get_min_bound()) + ) + radius = diameter * 200 # The radius of the sperical projection + _, pt_map = pcd.hidden_point_removal(camera_pose, radius) + + visible_point_ids = np.array(pt_map) + + point_visibility[visible_point_ids] = True + return point_visibility + + @torch.no_grad() + def hidden_judge(self, camera, texture_dim): + mesh = self.mesh + + verts = mesh.verts_packed() + faces = mesh.faces_packed() + verts_uv = mesh.textures.verts_uvs_padded()[0] # 获取打包后的 UV 坐标 (V, 2) + faces_uv = mesh.textures.faces_uvs_padded()[0] + uv_face_attr = torch.index_select( + verts_uv, 0, faces_uv.view(-1) + ) # 选择对应顶点的 UV 坐标 + uv_face_attr = uv_face_attr.view( + faces.shape[0], faces_uv.shape[1], 2 + ).unsqueeze(0) + x, y, z = verts[:, 0], verts[:, 1], verts[:, 2] + mesh_out_of_range = False + if ( + x.min() < -1 + or x.max() > 1 + or y.min() < -1 + or y.max() > 1 + or z.min() < -1 + or z.max() > 1 + ): + mesh_out_of_range = True + face_vertices_world = kal.ops.mesh.index_vertices_by_faces( + verts.unsqueeze(0), faces + ) + face_vertices_z = torch.zeros_like( + face_vertices_world[:, :, :, -1], device=verts.device + ) + uv_position, face_idx = kal.render.mesh.rasterize( + texture_dim, + texture_dim, + face_vertices_z, + uv_face_attr * 2 - 1, + face_features=face_vertices_world, + ) + uv_position = torch.clamp(uv_position, -1, 1) + uv_position[face_idx == -1] = 0 + + points = uv_position.reshape(-1, 3) + mask = points[:, 0] != 0 + valid_points = points[mask] + # np.save("tmp/pcd.npy", valid_points.cpu().numpy()) + # print(camera.get_camera_center()) + + points_visibility = self.get_point_validation_by_o3d( + valid_points, camera + ).float() + visibility_map = torch.zeros((texture_dim * texture_dim,)).to(self.device) + visibility_map[mask] = points_visibility + visibility_map = visibility_map.reshape((texture_dim, texture_dim)) + return visibility_map + + @torch.enable_grad() + def bake_texture( + self, + views=None, + main_views=[], + cos_weighted=True, + channels=None, + exp=None, + noisy=False, + generator=None, + smooth_colorize=False, + ): + if not exp: + exp = 1 + if not channels: + channels = self.channels + views = [view.permute(1, 2, 0) for view in views] + + tmp_mesh = self.mesh + bake_maps = [ + torch.zeros( + self.target_size + (views[0].shape[2],), + device=self.device, + requires_grad=True, + ) + for view in views + ] + optimizer = torch.optim.SGD(bake_maps, lr=1, momentum=0) + optimizer.zero_grad() + loss = 0 + for i in range(len(self.cameras)): + bake_tex = TexturesUV( + [bake_maps[i]], + tmp_mesh.textures.faces_uvs_padded(), + tmp_mesh.textures.verts_uvs_padded(), + sampling_mode=self.sampling_mode, + ) + tmp_mesh.textures = bake_tex + images_predicted = self.renderer( + tmp_mesh, + cameras=self.cameras[i], + lights=self.lights, + device=self.device, + ) + predicted_rgb = images_predicted[..., :-1] + loss += (((predicted_rgb[...] - views[i])) ** 2).sum() + loss.backward(retain_graph=False) + optimizer.step() + + total_weights = 0 + baked = 0 + for i in range(len(bake_maps)): + normalized_baked_map = bake_maps[i].detach() / ( + self.gradient_maps[i] + 1e-8 + ) + bake_map = voronoi_solve( + normalized_baked_map, self.gradient_maps[i][..., 0], self.device + ) + # bake_map = voronoi_solve(normalized_baked_map, self.visible_triangles[i].squeeze()) + + weight = self.visible_triangles[i] * (self.cos_maps[i]) ** exp + if smooth_colorize: + visibility_map = self.hidden_judge( + self.cameras[i], self.target_size[0] + ).unsqueeze(-1) + weight *= visibility_map + if noisy: + noise = ( + torch.rand(weight.shape[:-1] + (1,), generator=generator) + .type(weight.dtype) + .to(weight.device) + ) + weight *= noise + total_weights += weight + + baked += bake_map * weight + baked /= total_weights + 1e-8 + + whole_visible_mask = None + if not smooth_colorize: + baked = voronoi_solve(baked, total_weights[..., 0], self.device) + tmp_mesh.textures = TexturesUV( + [baked], + tmp_mesh.textures.faces_uvs_padded(), + tmp_mesh.textures.verts_uvs_padded(), + sampling_mode=self.sampling_mode, + ) + else: # smooth colorize + baked = voronoi_solve(baked, total_weights[..., 0], self.device) + whole_visible_mask = self.visible_triangles[0].to(torch.int32) + for tensor in self.visible_triangles[1:]: + whole_visible_mask = torch.bitwise_or( + whole_visible_mask, tensor.to(torch.int32) + ) + + baked *= whole_visible_mask + tmp_mesh.textures = TexturesUV( + [baked], + tmp_mesh.textures.faces_uvs_padded(), + tmp_mesh.textures.verts_uvs_padded(), + sampling_mode=self.sampling_mode, + ) + + extended_mesh = tmp_mesh.extend(len(self.cameras)) + images_predicted = self.renderer( + extended_mesh, cameras=self.cameras, lights=self.lights + ) + learned_views = [image.permute(2, 0, 1) for image in images_predicted] + + return learned_views, baked.permute(2, 0, 1), total_weights.permute(2, 0, 1) + + # Move the internel data to a specific device + def to(self, device): + for mesh_name in ["mesh", "mesh_d", "mesh_uv"]: + if hasattr(self, mesh_name): + mesh = getattr(self, mesh_name) + setattr(self, mesh_name, mesh.to(device)) + for list_name in ["visible_triangles", "visibility_maps", "cos_maps"]: + if hasattr(self, list_name): + map_list = getattr(self, list_name) + for i in range(len(map_list)): + map_list[i] = map_list[i].to(device) diff --git a/step1x3d_texture/renderer/shader.py b/step1x3d_texture/renderer/shader.py new file mode 100644 index 0000000000000000000000000000000000000000..e46dc1656dc9e5070c4938868e6b0b6304f1f878 --- /dev/null +++ b/step1x3d_texture/renderer/shader.py @@ -0,0 +1,127 @@ +from typing import Optional + +import torch +import pytorch3d + + +from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj +from pytorch3d.ops import interpolate_face_attributes + +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + look_at_view_transform, + FoVPerspectiveCameras, + AmbientLights, + PointLights, + DirectionalLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + SoftSilhouetteShader, + HardPhongShader, + TexturesVertex, + TexturesUV, + Materials, +) +from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend +from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties + +from pytorch3d.renderer.lighting import AmbientLights +from pytorch3d.renderer.materials import Materials +from pytorch3d.renderer.mesh.shader import ShaderBase +from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading +from pytorch3d.renderer.mesh.rasterizer import Fragments + + +""" + Customized the original pytorch3d hard flat shader to support N channel flat shading +""" + + +class HardNChannelFlatShader(ShaderBase): + """ + Per face lighting - the lighting model is applied using the average face + position and the face normal. The blending function hard assigns + the color of the closest face for each pixel. + + To use the default values, simply initialize the shader with the desired + device e.g. + + .. code-block:: + + shader = HardFlatShader(device=torch.device("cuda:0")) + """ + + def __init__( + self, + device="cpu", + cameras: Optional[TensorProperties] = None, + lights: Optional[TensorProperties] = None, + materials: Optional[Materials] = None, + blend_params: Optional[BlendParams] = None, + channels: int = 3, + ): + self.channels = channels + ones = ((1.0,) * channels,) + zeros = ((0.0,) * channels,) + + if ( + not isinstance(lights, AmbientLights) + or not lights.ambient_color.shape[-1] == channels + ): + lights = AmbientLights( + ambient_color=ones, + device=device, + ) + + if not materials or not materials.ambient_color.shape[-1] == channels: + materials = Materials( + device=device, + diffuse_color=zeros, + ambient_color=ones, + specular_color=zeros, + shininess=0.0, + ) + + blend_params_new = BlendParams(background_color=(1.0,) * channels) + if not isinstance(blend_params, BlendParams): + blend_params = blend_params_new + else: + background_color_ = blend_params.background_color + if ( + isinstance(background_color_, Sequence[float]) + and not len(background_color_) == channels + ): + blend_params = blend_params_new + if ( + isinstance(background_color_, torch.Tensor) + and not background_color_.shape[-1] == channels + ): + blend_params = blend_params_new + + super().__init__( + device, + cameras, + lights, + materials, + blend_params, + ) + + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: + cameras = super()._get_cameras(**kwargs) + texels = meshes.sample_textures(fragments) + lights = kwargs.get("lights", self.lights) + materials = kwargs.get("materials", self.materials) + blend_params = kwargs.get("blend_params", self.blend_params) + colors = flat_shading( + meshes=meshes, + fragments=fragments, + texels=texels, + lights=lights, + cameras=cameras, + materials=materials, + ) + images = hard_rgb_blend(colors, fragments, blend_params) + return images diff --git a/step1x3d_texture/renderer/voronoi.py b/step1x3d_texture/renderer/voronoi.py new file mode 100644 index 0000000000000000000000000000000000000000..745605a0647ee878fe2ed689a4ab3dfdfc3c32e2 --- /dev/null +++ b/step1x3d_texture/renderer/voronoi.py @@ -0,0 +1,172 @@ +""" +Program to compute Voronoi diagram using JFA. + +@author yisiox +@version September 2022 +""" + +import cupy as cp +from random import sample + +# global variables +x_dim = 512 +y_dim = 512 +noSeeds = 1024 + +# diagram is represented as a 2d array where each element is +# x coord of source * y_dim + y coord of source +ping = cp.full((x_dim, y_dim), -1, dtype=int) +pong = None + + +import torch +import time + + +def process_tensors(tensor1, tensor2): + # start_time = time.time() + + tensor2_unique = torch.unique(tensor2) + mask = torch.isin(tensor1, tensor2_unique, assume_unique=True) + tensor1[~mask] = -1 + + end_time = time.time() + # print(f"Computation time: {end_time - start_time} seconds") + + return tensor1 + + +def test_performance(): + computation_times = [] + + for _ in range(10): + tensor1 = torch.randint(0, 40001, (1024, 1024)).cuda() + tensor2 = torch.randint(5000, 15000, (512, 512)).cuda() + + process_tensors(tensor1, tensor2) + + +def voronoi_solve(texture, mask, device="cuda"): + """ + This is a warpper of the original cupy voronoi implementation + The texture color where mask value is 1 will propagate to its + neighbors. + args: + texture - A multi-channel tensor, (H, W, C) + mask - A single-channel tensor, (H, W) + return: + texture - Propagated tensor + """ + h, w, c = texture.shape + # hwc_texture = texture.permute(1,2,0) + valid_pix_coord = torch.where(mask > 0) + + indices = torch.arange(0, h * w).to(device).reshape(h, w) + idx_map = -1 * torch.ones((h, w), dtype=torch.int64).to(device) + idx_map[valid_pix_coord] = indices[valid_pix_coord] + + ping = cp.asarray(idx_map) + pong = cp.copy(ping) + ping = JFAVoronoiDiagram(ping, pong) + + voronoi_map = torch.as_tensor(ping, device=device) + nc_voronoi_texture = torch.index_select( + texture.reshape(h * w, c), 0, voronoi_map.reshape(h * w) + ) + voronoi_texture = nc_voronoi_texture.reshape(h, w, c) + + return voronoi_texture + + +def generateRandomSeeds(n): + """ + Function to generate n random seeds. + + @param n The number of seeds to generate. + """ + global ping, pong + + if n > x_dim * y_dim: + print("Error: Number of seeds greater than number of pixels.") + return + + # take sample of cartesian product + coords = [(x, y) for x in range(x_dim) for y in range(y_dim)] + seeds = sample(coords, n) + for i in range(n): + x, y = seeds[i] + ping[x, y] = x * y_dim + y + pong = cp.copy(ping) + + +displayKernel = cp.ElementwiseKernel( + "int64 x", "int64 y", f"y = (x < 0) ? x : x % 103", "displayTransform" +) + + +voronoiKernel = cp.RawKernel( + r""" + extern "C" __global__ + void voronoiPass(const long long step, const long long xDim, const long long yDim, const long long *ping, long long *pong) { + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + long long stp = blockDim.x * gridDim.x; + + for (long long k = idx; k < xDim * yDim; k += stp) { + long long dydx[] = {-1, 0, 1}; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + long long dx = (step * dydx[i]) * yDim; + long long dy = step * dydx[j]; + long long src = k + dx + dy; + if (src < 0 || src >= xDim * yDim) + continue; + if (ping[src] == -1) + continue; + if (pong[k] == -1) { + pong[k] = ping[src]; + continue; + } + long long x1 = k / yDim; + long long y1 = k % yDim; + long long x2 = pong[k] / yDim; + long long y2 = pong[k] % yDim; + long long x3 = ping[src] / yDim; + long long y3 = ping[src] % yDim; + long long curr_dist = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2); + long long jump_dist = (x1 - x3) * (x1 - x3) + (y1 - y3) * (y1 - y3); + if (jump_dist < curr_dist) + pong[k] = ping[src]; + } + } + } + } + """, + "voronoiPass", +) + + +""" + + y and x is actually w and h? (according to experiment result) + +""" + + +def JFAVoronoiDiagram(ping, pong): + # global ping, pong + # compute initial step size + x_dim, y_dim = ping.shape + step = max(x_dim, y_dim) // 2 + # initalise frame number and display original state + frame = 0 + # iterate while step size is greater than 0 + while step: + voronoiKernel( + (min(x_dim, 512),), (min(y_dim, 512),), (step, x_dim, y_dim, ping, pong) + ) + # Ajusted the upper bound of the kernel dimension from 1024 to 512 to avoid CUDA OUT OF RESOURCE problem + ping, pong = pong, ping + frame += 1 + step //= 2 + # displayDiagram(frame, ping) + return ping diff --git a/step1x3d_texture/schedulers/scheduler_utils.py b/step1x3d_texture/schedulers/scheduler_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d25a55eed58eb33c2932a4d8e7b8c424a7a7dcb5 --- /dev/null +++ b/step1x3d_texture/schedulers/scheduler_utils.py @@ -0,0 +1,70 @@ +import torch + + +def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device=None): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def SNR_to_betas(snr): + """ + Converts SNR to betas + """ + # alphas_cumprod = pass + # snr = (alpha / ) ** 2 + # alpha_t^2 / (1 - alpha_t^2) = snr + alpha_t = (snr / (1 + snr)) ** 0.5 + alphas_cumprod = alpha_t**2 + alphas = alphas_cumprod / torch.cat( + [torch.ones(1, device=snr.device), alphas_cumprod[:-1]] + ) + betas = 1 - alphas + return betas + + +def compute_snr(timesteps, noise_scheduler): + """ + Computes SNR as per Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + +def compute_alpha(timesteps, noise_scheduler): + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + return alpha diff --git a/step1x3d_texture/schedulers/scheduling_shift_snr.py b/step1x3d_texture/schedulers/scheduling_shift_snr.py new file mode 100644 index 0000000000000000000000000000000000000000..6da615031bd3a3239f88d184083b2746f2ffab42 --- /dev/null +++ b/step1x3d_texture/schedulers/scheduling_shift_snr.py @@ -0,0 +1,138 @@ +from typing import Any + +import torch + +from .scheduler_utils import SNR_to_betas, compute_snr + + +class ShiftSNRScheduler: + def __init__( + self, + noise_scheduler: Any, + timesteps: Any, + shift_scale: float, + scheduler_class: Any, + ): + self.noise_scheduler = noise_scheduler + self.timesteps = timesteps + self.shift_scale = shift_scale + self.scheduler_class = scheduler_class + + def _get_shift_scheduler(self): + """ + Prepare scheduler for shifted betas. + + :return: A scheduler object configured with shifted betas + """ + snr = compute_snr(self.timesteps, self.noise_scheduler) + shifted_betas = SNR_to_betas(snr / self.shift_scale) + + return self.scheduler_class.from_config( + self.noise_scheduler.config, trained_betas=shifted_betas.numpy() + ) + + def _get_interpolated_shift_scheduler(self): + """ + Prepare scheduler for shifted betas and interpolate with the original betas in log space. + + :return: A scheduler object configured with interpolated shifted betas + """ + snr = compute_snr(self.timesteps, self.noise_scheduler) + shifted_snr = snr / self.shift_scale + + weighting = self.timesteps.float() / ( + self.noise_scheduler.config.num_train_timesteps - 1 + ) + interpolated_snr = torch.exp( + torch.log(snr) * (1 - weighting) + torch.log(shifted_snr) * weighting + ) + + shifted_betas = SNR_to_betas(interpolated_snr) + + return self.scheduler_class.from_config( + self.noise_scheduler.config, trained_betas=shifted_betas.numpy() + ) + + @classmethod + def from_scheduler( + cls, + noise_scheduler: Any, + shift_mode: str = "default", + timesteps: Any = None, + shift_scale: float = 1.0, + scheduler_class: Any = None, + ): + # Check input + if timesteps is None: + timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps) + if scheduler_class is None: + scheduler_class = noise_scheduler.__class__ + + # Create scheduler + shift_scheduler = cls( + noise_scheduler=noise_scheduler, + timesteps=timesteps, + shift_scale=shift_scale, + scheduler_class=scheduler_class, + ) + + if shift_mode == "default": + return shift_scheduler._get_shift_scheduler() + elif shift_mode == "interpolated": + return shift_scheduler._get_interpolated_shift_scheduler() + else: + raise ValueError(f"Unknown shift_mode: {shift_mode}") + + +if __name__ == "__main__": + """ + Compare the alpha values for different noise schedulers. + """ + import matplotlib.pyplot as plt + from diffusers import DDPMScheduler + + from .scheduler_utils import compute_alpha + + # Base + timesteps = torch.arange(0, 1000) + noise_scheduler_base = DDPMScheduler.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="scheduler" + ) + alpha = compute_alpha(timesteps, noise_scheduler_base) + plt.plot(timesteps.numpy(), alpha.numpy(), label="Base") + + # Kolors + num_train_timesteps_ = 1100 + timesteps_ = torch.arange(0, num_train_timesteps_) + noise_kwargs = {"beta_end": 0.014, "num_train_timesteps": num_train_timesteps_} + noise_scheduler_kolors = DDPMScheduler.from_config( + noise_scheduler_base.config, **noise_kwargs + ) + alpha = compute_alpha(timesteps_, noise_scheduler_kolors) + plt.plot(timesteps_.numpy(), alpha.numpy(), label="Kolors") + + # Shift betas + shift_scale = 8.0 + noise_scheduler_shift = ShiftSNRScheduler.from_scheduler( + noise_scheduler_base, shift_mode="default", shift_scale=shift_scale + ) + alpha = compute_alpha(timesteps, noise_scheduler_shift) + plt.plot(timesteps.numpy(), alpha.numpy(), label="Shift Noise (scale 8.0)") + + # Shift betas (interpolated) + noise_scheduler_inter = ShiftSNRScheduler.from_scheduler( + noise_scheduler_base, shift_mode="interpolated", shift_scale=shift_scale + ) + alpha = compute_alpha(timesteps, noise_scheduler_inter) + plt.plot(timesteps.numpy(), alpha.numpy(), label="Interpolated (scale 8.0)") + + # ZeroSNR + noise_scheduler = DDPMScheduler.from_config( + noise_scheduler_base.config, rescale_betas_zero_snr=True + ) + alpha = compute_alpha(timesteps, noise_scheduler) + plt.plot(timesteps.numpy(), alpha.numpy(), label="ZeroSNR") + + plt.legend() + plt.grid() + plt.savefig("check_alpha.png") diff --git a/step1x3d_texture/systems/__init__.py b/step1x3d_texture/systems/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/step1x3d_texture/systems/base.py b/step1x3d_texture/systems/base.py new file mode 100644 index 0000000000000000000000000000000000000000..566e98048841a4ee7410d280623d5cdffd6d678b --- /dev/null +++ b/step1x3d_texture/systems/base.py @@ -0,0 +1,262 @@ +import os +from dataclasses import dataclass, field + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + +from ..utils.base import Updateable, update_end_if_possible, update_if_possible +from ..utils.config import parse_structured +from ..utils.core import debug, find, info, warn +from ..utils.misc import ( + C, + cleanup, + get_device, + get_rank, + load_module_weights, + show_vram_usage, +) +from ..utils.saving import SaverMixin +from ..utils.typing import * +from .utils import parse_optimizer, parse_scheduler + + +class BaseSystem(pl.LightningModule, Updateable, SaverMixin): + @dataclass + class Config: + optimizer: dict = field(default_factory=dict) + scheduler: Optional[dict] = None + weights: Optional[str] = None + weights_ignore_modules: Optional[List[str]] = None + weights_mapping: Optional[List[Dict[str, str]]] = None + check_train_every_n_steps: int = 0 + check_val_limit_rank: int = 8 + cleanup_after_validation_step: bool = False + cleanup_after_test_step: bool = False + allow_tf32: bool = True + + cfg: Config + + def __init__(self, cfg, resumed=False) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self._save_dir: Optional[str] = None + self._resumed: bool = resumed + self._resumed_eval: bool = False + self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0} + + # weird fix for extra VRAM usage on rank 0 + # credit: https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113 + torch.cuda.set_device(get_rank()) + torch.cuda.empty_cache() + + torch.backends.cuda.matmul.allow_tf32 = self.cfg.allow_tf32 + + self.configure() + if self.cfg.weights is not None: + self.load_weights( + self.cfg.weights, + self.cfg.weights_ignore_modules, + self.cfg.weights_mapping, + ) + self.post_configure() + + def load_weights( + self, + weights: str, + ignore_modules: Optional[List[str]] = None, + mapping: Optional[List[Dict[str, str]]] = None, + ): + state_dict, epoch, global_step = load_module_weights( + weights, + ignore_modules=ignore_modules, + mapping=mapping, + map_location="cpu", + ) + self.load_state_dict(state_dict, strict=False) + # restore step-dependent states + self.do_update_step(epoch, global_step, on_load_weights=True) + + def set_resume_status(self, current_epoch: int, global_step: int): + # restore correct epoch and global step in eval + self._resumed_eval = True + self._resumed_eval_status["current_epoch"] = current_epoch + self._resumed_eval_status["global_step"] = global_step + + @property + def resumed(self): + # whether from resumed checkpoint + return self._resumed + + @property + def true_global_step(self): + if self._resumed_eval: + return self._resumed_eval_status["global_step"] + else: + return self.global_step + + @property + def true_current_epoch(self): + if self._resumed_eval: + return self._resumed_eval_status["current_epoch"] + else: + return self.current_epoch + + def configure(self) -> None: + pass + + def post_configure(self) -> None: + """ + executed after weights are loaded + """ + pass + + def C(self, value: Any) -> float: + return C(value, self.true_current_epoch, self.true_global_step) + + def configure_optimizers(self): + optim = parse_optimizer(self.cfg.optimizer, self) + ret = { + "optimizer": optim, + } + if self.cfg.scheduler is not None: + ret.update( + { + "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim), + } + ) + return ret + + def on_fit_start(self) -> None: + if self._save_dir is not None: + info(f"Validation results will be saved to {self._save_dir}") + else: + warn( + f"Saving directory not set for the system, visualization results will not be saved" + ) + + def training_step(self, batch, batch_idx): + raise NotImplementedError + + def check_train(self, batch, **kwargs): + if ( + self.global_rank == 0 + and self.cfg.check_train_every_n_steps > 0 + and self.true_global_step % self.cfg.check_train_every_n_steps == 0 + ): + self.on_check_train(batch, **kwargs) + + def on_check_train(self, batch, outputs, **kwargs): + pass + + def validation_step(self, batch, batch_idx): + raise NotImplementedError + + def on_validation_epoch_end(self): + pass + + def test_step(self, batch, batch_idx): + raise NotImplementedError + + def on_test_epoch_end(self): + pass + + def on_test_end(self) -> None: + if self._save_dir is not None: + info(f"Test results saved to {self._save_dir}") + + def on_predict_start(self) -> None: + pass + + def predict_step(self, batch, batch_idx): + pass + + def on_predict_epoch_end(self) -> None: + pass + + def on_predict_end(self) -> None: + pass + + def preprocess_data(self, batch, stage): + pass + + """ + Implementing on_after_batch_transfer of DataModule does the same. + But on_after_batch_transfer does not support DP. + """ + + def on_train_batch_start(self, batch, batch_idx, unused=0): + self.preprocess_data(batch, "train") + self.dataset = self.trainer.train_dataloader.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0): + self.preprocess_data(batch, "validation") + self.dataset = self.trainer.val_dataloaders.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0): + self.preprocess_data(batch, "test") + self.dataset = self.trainer.test_dataloaders.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0): + self.preprocess_data(batch, "predict") + self.dataset = self.trainer.predict_dataloaders.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_train_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.train_dataloader.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + + def on_validation_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.val_dataloaders.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + if self.cfg.cleanup_after_validation_step: + # cleanup to save vram + cleanup() + + def on_test_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.test_dataloaders.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + if self.cfg.cleanup_after_test_step: + # cleanup to save vram + cleanup() + + def on_predict_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.predict_dataloaders.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + if self.cfg.cleanup_after_test_step: + # cleanup to save vram + cleanup() + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + pass + + def on_before_optimizer_step(self, optimizer): + """ + # some gradient-related debugging goes here, example: + from lightning.pytorch.utilities import grad_norm + norms = grad_norm(self.geometry, norm_type=2) + print(norms) + for name, p in self.named_parameters(): + if p.grad is None: + info(f"{name} does not receive gradients!") + """ + pass diff --git a/step1x3d_texture/systems/ig2mv_sdxl.py b/step1x3d_texture/systems/ig2mv_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..1914eb7deebb36e5332183ed2ec3e24d7a0d54c2 --- /dev/null +++ b/step1x3d_texture/systems/ig2mv_sdxl.py @@ -0,0 +1,506 @@ +import os +import random +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F +from diffusers import DDPMScheduler, UNet2DConditionModel +from diffusers.models import AutoencoderKL +from diffusers.training_utils import compute_snr +from einops import rearrange +from omegaconf import OmegaConf +from PIL import Image + +from ..pipelines.ig2mv_sdxl_pipeline import IG2MVSDXLPipeline +from ..schedulers.scheduling_shift_snr import ShiftSNRScheduler +from ..utils.core import find +from ..utils.typing import * +from .base import BaseSystem +from .utils import encode_prompt, vae_encode + + +def compute_embeddings( + prompt_batch, + empty_prompt_indices, + text_encoders, + tokenizers, + is_train=True, + **kwargs, +): + original_size = kwargs["original_size"] + target_size = kwargs["target_size"] + crops_coords_top_left = kwargs["crops_coords_top_left"] + + for i in range(empty_prompt_indices.shape[0]): + if empty_prompt_indices[i]: + prompt_batch[i] = "" + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, 0, is_train + ) + add_text_embeds = pooled_prompt_embeds.to( + device=prompt_embeds.device, dtype=prompt_embeds.dtype + ) + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = add_time_ids.to( + device=prompt_embeds.device, dtype=prompt_embeds.dtype + ) + + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + +class IG2MVSDXLSystem(BaseSystem): + @dataclass + class Config(BaseSystem.Config): + + # Model / Adapter + pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-xl-base-1.0" + pretrained_vae_name_or_path: Optional[str] = "madebyollin/sdxl-vae-fp16-fix" + pretrained_adapter_name_or_path: Optional[str] = None + pretrained_unet_name_or_path: Optional[str] = None + init_adapter_kwargs: Dict[str, Any] = field(default_factory=dict) + + use_fp16_vae: bool = True + use_fp16_clip: bool = True + + # Training + trainable_modules: List[str] = field(default_factory=list) + train_cond_encoder: bool = True + prompt_drop_prob: float = 0.0 + image_drop_prob: float = 0.0 + cond_drop_prob: float = 0.0 + + gradient_checkpointing: bool = False + + # Noise sampler + noise_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict) + noise_offset: float = 0.0 + input_perturbation: float = 0.0 + snr_gamma: Optional[float] = 5.0 + prediction_type: Optional[str] = None + shift_noise: bool = False + shift_noise_mode: str = "interpolated" + shift_noise_scale: float = 1.0 + + # Evaluation + eval_seed: int = 0 + eval_num_inference_steps: int = 30 + eval_guidance_scale: float = 1.0 + eval_height: int = 512 + eval_width: int = 512 + + cfg: Config + + def configure(self): + super().configure() + + # Prepare pipeline + pipeline_kwargs = {} + if self.cfg.pretrained_vae_name_or_path is not None: + pipeline_kwargs["vae"] = AutoencoderKL.from_pretrained( + self.cfg.pretrained_vae_name_or_path + ) + if self.cfg.pretrained_unet_name_or_path is not None: + pipeline_kwargs["unet"] = UNet2DConditionModel.from_pretrained( + self.cfg.pretrained_unet_name_or_path + ) + + pipeline: IG2MVSDXLPipeline + pipeline = IG2MVSDXLPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, **pipeline_kwargs + ) + + init_adapter_kwargs = OmegaConf.to_container(self.cfg.init_adapter_kwargs) + if "self_attn_processor" in init_adapter_kwargs: + self_attn_processor = init_adapter_kwargs["self_attn_processor"] + if self_attn_processor is not None and isinstance(self_attn_processor, str): + self_attn_processor = find(self_attn_processor) + init_adapter_kwargs["self_attn_processor"] = self_attn_processor + pipeline.init_custom_adapter(**init_adapter_kwargs) + + if self.cfg.pretrained_adapter_name_or_path: + pretrained_path = os.path.dirname(self.cfg.pretrained_adapter_name_or_path) + adapter_name = os.path.basename(self.cfg.pretrained_adapter_name_or_path) + pipeline.load_custom_adapter(pretrained_path, weight_name=adapter_name) + + noise_scheduler = DDPMScheduler.from_config( + pipeline.scheduler.config, **self.cfg.noise_scheduler_kwargs + ) + if self.cfg.shift_noise: + noise_scheduler = ShiftSNRScheduler.from_scheduler( + noise_scheduler, + shift_mode=self.cfg.shift_noise_mode, + shift_scale=self.cfg.shift_noise_scale, + scheduler_class=DDPMScheduler, + ) + pipeline.scheduler = noise_scheduler + + # Prepare models + self.pipeline: IG2MVSDXLPipeline = pipeline + self.vae = self.pipeline.vae.to( + dtype=torch.float16 if self.cfg.use_fp16_vae else torch.float32 + ) + self.tokenizer = self.pipeline.tokenizer + self.tokenizer_2 = self.pipeline.tokenizer_2 + self.text_encoder = self.pipeline.text_encoder.to( + dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32 + ) + self.text_encoder_2 = self.pipeline.text_encoder_2.to( + dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32 + ) + self.feature_extractor = self.pipeline.feature_extractor + + self.cond_encoder = self.pipeline.cond_encoder + self.unet = self.pipeline.unet + self.noise_scheduler = self.pipeline.scheduler + self.inference_scheduler = DDPMScheduler.from_config( + self.noise_scheduler.config + ) + self.pipeline.scheduler = self.inference_scheduler + if self.cfg.prediction_type is not None: + self.noise_scheduler.register_to_config( + prediction_type=self.cfg.prediction_type + ) + + # Prepare trainable / non-trainable modules + trainable_modules = self.cfg.trainable_modules + if trainable_modules and len(trainable_modules) > 0: + self.unet.requires_grad_(False) + for name, module in self.unet.named_modules(): + for trainable_module in trainable_modules: + if trainable_module in name: + module.requires_grad_(True) + else: + self.unet.requires_grad_(True) + self.cond_encoder.requires_grad_(self.cfg.train_cond_encoder) + + self.vae.requires_grad_(False) + self.text_encoder.requires_grad_(False) + self.text_encoder_2.requires_grad_(False) + + # Others + # Prepare gradient checkpointing + if self.cfg.gradient_checkpointing: + self.unet.enable_gradient_checkpointing() + + def forward( + self, + noisy_latents: Tensor, + conditioning_pixel_values: Tensor, + timesteps: Tensor, + ref_latents: Tensor, + prompts: List[str], + num_views: int, + **kwargs, + ) -> Dict[str, Any]: + bsz = noisy_latents.shape[0] + b_samples = bsz // num_views + num_batch_images = num_views + + prompt_drop_mask = ( + torch.rand(b_samples, device=noisy_latents.device) + < self.cfg.prompt_drop_prob + ) + image_drop_mask = ( + torch.rand(b_samples, device=noisy_latents.device) + < self.cfg.image_drop_prob + ) + cond_drop_mask = ( + torch.rand(b_samples, device=noisy_latents.device) < self.cfg.cond_drop_prob + ) + prompt_drop_mask = prompt_drop_mask | cond_drop_mask + image_drop_mask = image_drop_mask | cond_drop_mask + + with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + additional_embeds = compute_embeddings( + prompts, + prompt_drop_mask, + [self.text_encoder, self.text_encoder_2], + [self.tokenizer, self.tokenizer_2], + **kwargs, + ) + + # Process reference latents to obtain reference features + with torch.no_grad(): + ref_timesteps = torch.zeros_like(timesteps[:b_samples]) + ref_hidden_states = {} + self.unet( + ref_latents, + ref_timesteps, + encoder_hidden_states=additional_embeds["prompt_embeds"], + added_cond_kwargs={ + "text_embeds": additional_embeds["text_embeds"], + "time_ids": additional_embeds["time_ids"], + }, + cross_attention_kwargs={ + "cache_hidden_states": ref_hidden_states, + "use_mv": False, + "use_ref": False, + }, + return_dict=False, + ) + for k, v in ref_hidden_states.items(): + v_ = v + v_[image_drop_mask] = 0.0 + ref_hidden_states[k] = v_.repeat_interleave(num_batch_images, dim=0) + + # Repeat additional embeddings for each image in the batch + for key, value in additional_embeds.items(): + kwargs[key] = value.repeat_interleave(num_batch_images, dim=0) + + conditioning_features = self.cond_encoder(conditioning_pixel_values) + + added_cond_kwargs = { + "text_embeds": kwargs["text_embeds"], + "time_ids": kwargs["time_ids"], + } + + noise_pred = self.unet( + noisy_latents, + timesteps, + encoder_hidden_states=kwargs["prompt_embeds"], + added_cond_kwargs=added_cond_kwargs, + down_intrablock_additional_residuals=conditioning_features, + cross_attention_kwargs={ + "ref_hidden_states": ref_hidden_states, + "num_views": num_views, + }, + ).sample + + return {"noise_pred": noise_pred} + + def training_step(self, batch, batch_idx): + num_views = batch["num_views"] + + vae_max_slice = 8 + with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): + latents = [] + for i in range(0, batch["rgb"].shape[0], vae_max_slice): + latents.append( + vae_encode( + self.vae, + batch["rgb"][i : i + vae_max_slice].to(self.vae.dtype) * 2 - 1, + sample=True, + apply_scale=True, + ).float() + ) + latents = torch.cat(latents, dim=0) + + with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): + ref_latents = vae_encode( + self.vae, + batch["reference_rgb"].to(self.vae.dtype) * 2 - 1, + sample=True, + apply_scale=True, + ).float() + + bsz = latents.shape[0] + b_samples = bsz // num_views + + noise = torch.randn_like(latents) + if self.cfg.noise_offset is not None: + # # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += self.cfg.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + + noise_mask = ( + batch["noise_mask"] + if "noise_mask" in batch + else torch.ones((bsz,), dtype=torch.bool, device=latents.device) + ) + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (b_samples,), + device=latents.device, + dtype=torch.long, + ) + timesteps = timesteps.repeat_interleave(num_views) + timesteps[~noise_mask] = 0 + + if self.cfg.input_perturbation is not None: + new_noise = noise + self.cfg.input_perturbation * torch.randn_like(noise) + noisy_latents = self.noise_scheduler.add_noise( + latents, new_noise, timesteps + ) + else: + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + + noisy_latents[~noise_mask] = latents[~noise_mask] + + if self.noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif self.noise_scheduler.config.prediction_type == "v_prediction": + target = self.noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f"Unsupported prediction type {self.noise_scheduler.config.prediction_type}" + ) + + model_pred = self( + noisy_latents, batch["source_rgb"], timesteps, ref_latents, **batch + )["noise_pred"] + + model_pred = model_pred[noise_mask] + target = target[noise_mask] + + if self.cfg.snr_gamma is None: + loss = F.mse_loss(model_pred, target, reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(self.noise_scheduler, timesteps) + if self.noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, self.cfg.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + + loss = F.mse_loss(model_pred, target, reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + self.log("train/loss", loss, prog_bar=True) + + # will execute self.on_check_train every self.cfg.check_train_every_n_steps steps + self.check_train(batch) + + return {"loss": loss} + + def on_train_batch_end(self, outputs, batch, batch_idx): + pass + + def get_input_visualizations(self, batch): + return [ + { + "type": "rgb", + "img": rearrange( + batch["source_rgb"], + "(B N) C H W -> (B H) (N W) C", + N=batch["num_views"], + ), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": rearrange( + batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"] + ), + "kwargs": {"data_format": "HWC"}, + }, + ] + + def get_output_visualizations(self, batch, outputs): + images = [ + { + "type": "rgb", + "img": rearrange( + batch["source_rgb"], + "(B N) C H W -> (B H) (N W) C", + N=batch["num_views"], + ), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": rearrange( + batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"] + ), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": rearrange( + outputs, "(B N) C H W -> (B H) (N W) C", N=batch["num_views"] + ), + "kwargs": {"data_format": "HWC"}, + }, + ] + return images + + def generate_images(self, batch, **kwargs): + return self.pipeline( + prompt=batch["prompts"], + control_image=batch["source_rgb"], + num_images_per_prompt=batch["num_views"], + generator=torch.Generator(device=self.device).manual_seed( + self.cfg.eval_seed + ), + num_inference_steps=self.cfg.eval_num_inference_steps, + guidance_scale=self.cfg.eval_guidance_scale, + height=self.cfg.eval_height, + width=self.cfg.eval_width, + reference_image=batch["reference_rgb"], + output_type="pt", + ).images + + def on_save_checkpoint(self, checkpoint): + if self.global_rank == 0: + self.pipeline.save_custom_adapter( + os.path.dirname(self.get_save_dir()), + "step1x-3d-ig2v.safetensors", + safe_serialization=True, + include_keys=self.cfg.trainable_modules, + ) + + def on_check_train(self, batch): + self.save_image_grid( + f"it{self.true_global_step}-train.jpg", + self.get_input_visualizations(batch), + name="train_step_input", + step=self.true_global_step, + ) + + def validation_step(self, batch, batch_idx): + out = self.generate_images(batch) + + if ( + self.cfg.check_val_limit_rank > 0 + and self.global_rank < self.cfg.check_val_limit_rank + ): + self.save_image_grid( + f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}.jpg", + self.get_output_visualizations(batch, out), + name=f"validation_step_output_{self.global_rank}_{batch_idx}", + step=self.true_global_step, + ) + + def on_validation_epoch_end(self): + pass + + def test_step(self, batch, batch_idx): + out = self.generate_images(batch) + + self.save_image_grid( + f"it{self.true_global_step}-test-{self.global_rank}_{batch_idx}.jpg", + self.get_output_visualizations(batch, out), + name=f"test_step_output_{self.global_rank}_{batch_idx}", + step=self.true_global_step, + ) + + def on_test_end(self): + pass diff --git a/step1x3d_texture/systems/utils.py b/step1x3d_texture/systems/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..95599a40bbe2db32ff00ef01050f2837954afa3d --- /dev/null +++ b/step1x3d_texture/systems/utils.py @@ -0,0 +1,178 @@ +import random + +import numpy as np +import torch +import torch.nn as nn +from diffusers import AutoencoderKL +from torch.optim import lr_scheduler + +from ..utils.core import debug, find, info, warn +from ..utils.typing import * + +"""Diffusers Model Utils""" + + +def vae_encode( + vae: AutoencoderKL, + pixel_values: Float[Tensor, "B 3 H W"], + sample: bool = True, + apply_scale: bool = True, +): + latent_dist = vae.encode(pixel_values).latent_dist + latents = latent_dist.sample() if sample else latent_dist.mode() + if apply_scale: + latents = latents * vae.config.scaling_factor + return latents + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True +): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +CLIP_INPUT_MEAN = torch.as_tensor( + [0.48145466, 0.4578275, 0.40821073], dtype=torch.float32 +)[None, :, None, None] +CLIP_INPUT_STD = torch.as_tensor( + [0.26862954, 0.26130258, 0.27577711], dtype=torch.float32 +)[None, :, None, None] + + +def normalize_image_for_clip(image: Float[Tensor, "B C H W"]): + return (image - CLIP_INPUT_MEAN.to(image)) / CLIP_INPUT_STD.to(image) + + +"""Training""" + + +def get_scheduler(name): + if hasattr(lr_scheduler, name): + return getattr(lr_scheduler, name) + else: + raise NotImplementedError + + +def getattr_recursive(m, attr): + for name in attr.split("."): + m = getattr(m, name) + return m + + +def get_parameters(model, name): + module = getattr_recursive(model, name) + if isinstance(module, nn.Module): + return module.parameters() + elif isinstance(module, nn.Parameter): + return module + return [] + + +def parse_optimizer(config, model): + if hasattr(config, "params"): + params = [ + {"params": get_parameters(model, name), "name": name, **args} + for name, args in config.params.items() + ] + debug(f"Specify optimizer params: {config.params}") + else: + params = model.parameters() + if config.name in ["FusedAdam"]: + import apex + + optim = getattr(apex.optimizers, config.name)(params, **config.args) + elif config.name in ["Adam8bit", "AdamW8bit"]: + import bitsandbytes as bnb + + optim = bnb.optim.Adam8bit(params, **config.args) + else: + optim = getattr(torch.optim, config.name)(params, **config.args) + return optim + + +def parse_scheduler_to_instance(config, optimizer): + if config.name == "ChainedScheduler": + schedulers = [ + parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers + ] + scheduler = lr_scheduler.ChainedScheduler(schedulers) + elif config.name == "Sequential": + schedulers = [ + parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers + ] + scheduler = lr_scheduler.SequentialLR( + optimizer, schedulers, milestones=config.milestones + ) + else: + scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args) + return scheduler + + +def parse_scheduler(config, optimizer): + interval = config.get("interval", "epoch") + assert interval in ["epoch", "step"] + if config.name == "SequentialLR": + scheduler = { + "scheduler": lr_scheduler.SequentialLR( + optimizer, + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ], + milestones=config.milestones, + ), + "interval": interval, + } + elif config.name == "ChainedScheduler": + scheduler = { + "scheduler": lr_scheduler.ChainedScheduler( + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ] + ), + "interval": interval, + } + else: + scheduler = { + "scheduler": get_scheduler(config.name)(optimizer, **config.args), + "interval": interval, + } + return scheduler diff --git a/step1x3d_texture/utils/__init__.py b/step1x3d_texture/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db8ca54e749eef587c0a18ebf961cd8aa15c4289 --- /dev/null +++ b/step1x3d_texture/utils/__init__.py @@ -0,0 +1,3 @@ +from .camera import get_camera, get_orthogonal_camera +from .geometry import get_plucker_embeds_from_cameras_ortho +from .saving import make_image_grid, tensor_to_image diff --git a/step1x3d_texture/utils/base.py b/step1x3d_texture/utils/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f78adda249e57664ee2ee8b38b41caf5e80a780f --- /dev/null +++ b/step1x3d_texture/utils/base.py @@ -0,0 +1,123 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from .config import parse_structured +from .misc import get_device, load_module_weights +from .typing import * + + +class Configurable: + @dataclass + class Config: + pass + + def __init__(self, cfg: Optional[dict] = None) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + + +class Updateable: + def do_update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step( + epoch, global_step, on_load_weights=on_load_weights + ) + self.update_step(epoch, global_step, on_load_weights=on_load_weights) + + def do_update_step_end(self, epoch: int, global_step: int): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + self.update_step_end(epoch, global_step) + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # override this method to implement custom update logic + # if on_load_weights is True, you should be careful doing things related to model evaluations, + # as the models and tensors are not guarenteed to be on the same device + pass + + def update_step_end(self, epoch: int, global_step: int): + pass + + +def update_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step(epoch, global_step) + + +def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + + +class BaseObject(Updateable): + @dataclass + class Config: + pass + + cfg: Config # add this to every subclass of BaseObject to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.device = get_device() + self.configure(*args, **kwargs) + + def configure(self, *args, **kwargs) -> None: + pass + + +class BaseModule(nn.Module, Updateable): + @dataclass + class Config: + weights: Optional[str] = None + + cfg: Config # add this to every subclass of BaseModule to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.device = get_device() + self._non_modules = {} + self.configure(*args, **kwargs) + if self.cfg.weights is not None: + # format: path/to/weights:module_name + weights_path, module_name = self.cfg.weights.split(":") + state_dict, epoch, global_step = load_module_weights( + weights_path, module_name=module_name, map_location="cpu" + ) + self.load_state_dict(state_dict) + self.do_update_step( + epoch, global_step, on_load_weights=True + ) # restore states + + def configure(self, *args, **kwargs) -> None: + pass + + def register_non_module(self, name: str, module: nn.Module) -> None: + # non-modules won't be treated as model parameters + self._non_modules[name] = module + + def non_module(self, name: str): + return self._non_modules.get(name, None) diff --git a/step1x3d_texture/utils/callbacks.py b/step1x3d_texture/utils/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..50cb08a514b606ebb4ed6ecc0ed8c2f3d985f3e5 --- /dev/null +++ b/step1x3d_texture/utils/callbacks.py @@ -0,0 +1,158 @@ +import os +import shutil +import subprocess + +import pytorch_lightning + +from .config import dump_config +from .misc import parse_version + +if parse_version(pytorch_lightning.__version__) > parse_version("1.8"): + from pytorch_lightning.callbacks import Callback +else: + from pytorch_lightning.callbacks.base import Callback + +from pytorch_lightning.callbacks.progress import TQDMProgressBar +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn + + +class VersionedCallback(Callback): + def __init__(self, save_root, version=None, use_version=True): + self.save_root = save_root + self._version = version + self.use_version = use_version + + @property + def version(self) -> int: + """Get the experiment version. + + Returns: + The experiment version if specified else the next version. + """ + if self._version is None: + self._version = self._get_next_version() + return self._version + + def _get_next_version(self): + existing_versions = [] + if os.path.isdir(self.save_root): + for f in os.listdir(self.save_root): + bn = os.path.basename(f) + if bn.startswith("version_"): + dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") + existing_versions.append(int(dir_ver)) + if len(existing_versions) == 0: + return 0 + return max(existing_versions) + 1 + + @property + def savedir(self): + if not self.use_version: + return self.save_root + return os.path.join( + self.save_root, + ( + self.version + if isinstance(self.version, str) + else f"version_{self.version}" + ), + ) + + +class CodeSnapshotCallback(VersionedCallback): + def __init__(self, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + + def get_file_list(self): + return [ + b.decode() + for b in set( + subprocess.check_output( + 'git ls-files -- ":!:load/*"', shell=True + ).splitlines() + ) + | set( # hard code, TODO: use config to exclude folders or files + subprocess.check_output( + "git ls-files --others --exclude-standard", shell=True + ).splitlines() + ) + ] + + @rank_zero_only + def save_code_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + for f in self.get_file_list(): + if not os.path.exists(f) or os.path.isdir(f): + continue + os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) + shutil.copyfile(f, os.path.join(self.savedir, f)) + + def on_fit_start(self, trainer, pl_module): + try: + self.save_code_snapshot() + except: + rank_zero_warn( + "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." + ) + + +class ConfigSnapshotCallback(VersionedCallback): + def __init__(self, config_path, config, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + self.config_path = config_path + self.config = config + + @rank_zero_only + def save_config_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config) + shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml")) + + def on_fit_start(self, trainer, pl_module): + self.save_config_snapshot() + + +class CustomProgressBar(TQDMProgressBar): + def get_metrics(self, *args, **kwargs): + # don't show the version number + items = super().get_metrics(*args, **kwargs) + items.pop("v_num", None) + return items + + +class ProgressCallback(Callback): + def __init__(self, save_path): + super().__init__() + self.save_path = save_path + self._file_handle = None + + @property + def file_handle(self): + if self._file_handle is None: + self._file_handle = open(self.save_path, "w") + return self._file_handle + + @rank_zero_only + def write(self, msg: str) -> None: + self.file_handle.seek(0) + self.file_handle.truncate() + self.file_handle.write(msg) + self.file_handle.flush() + + @rank_zero_only + def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): + self.write( + f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%" + ) + + @rank_zero_only + def on_validation_start(self, trainer, pl_module): + self.write(f"Rendering validation image ...") + + @rank_zero_only + def on_test_start(self, trainer, pl_module): + self.write(f"Rendering video ...") + + @rank_zero_only + def on_predict_start(self, trainer, pl_module): + self.write(f"Exporting mesh assets ...") diff --git a/step1x3d_texture/utils/camera.py b/step1x3d_texture/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..b4275ee85d75433d0fc880b6673edd65a6a02d42 --- /dev/null +++ b/step1x3d_texture/utils/camera.py @@ -0,0 +1,211 @@ +import math +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +import trimesh +from PIL import Image +from torch import BoolTensor, FloatTensor + +LIST_TYPE = Union[list, np.ndarray, torch.Tensor] + + +def list_to_pt( + x: LIST_TYPE, dtype: Optional[torch.dtype] = None, device: Optional[str] = None +) -> torch.Tensor: + if isinstance(x, list) or isinstance(x, np.ndarray): + return torch.tensor(x, dtype=dtype, device=device) + return x.to(dtype=dtype) + + +def get_c2w( + elevation_deg: LIST_TYPE, + distance: LIST_TYPE, + azimuth_deg: Optional[LIST_TYPE], + num_views: Optional[int] = 1, + device: Optional[str] = None, +) -> torch.FloatTensor: + if azimuth_deg is None: + assert ( + num_views is not None + ), "num_views must be provided if azimuth_deg is None." + azimuth_deg = torch.linspace( + 0, 360, num_views + 1, dtype=torch.float32, device=device + )[:-1] + else: + num_views = len(azimuth_deg) + azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device) + elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device) + camera_distances = list_to_pt(distance, dtype=torch.float32, device=device) + elevation = elevation_deg * math.pi / 180 + azimuth = azimuth_deg * math.pi / 180 + camera_positions = torch.stack( + [ + camera_distances * torch.cos(elevation) * torch.cos(azimuth), + camera_distances * torch.cos(elevation) * torch.sin(azimuth), + camera_distances * torch.sin(elevation), + ], + dim=-1, + ) + center = torch.zeros_like(camera_positions) + up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[None, :].repeat( + num_views, 1 + ) + lookat = F.normalize(center - camera_positions, dim=-1) + right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1) + up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1) + c2w3x4 = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], + dim=-1, + ) + c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1) + c2w[:, 3, 3] = 1.0 + return c2w + + +def get_projection_matrix( + fovy_deg: LIST_TYPE, + aspect_wh: float = 1.0, + near: float = 0.1, + far: float = 100.0, + device: Optional[str] = None, +) -> torch.FloatTensor: + fovy_deg = list_to_pt(fovy_deg, dtype=torch.float32, device=device) + batch_size = fovy_deg.shape[0] + fovy = fovy_deg * math.pi / 180 + tan_half_fovy = torch.tan(fovy / 2) + projection_matrix = torch.zeros( + batch_size, 4, 4, dtype=torch.float32, device=device + ) + projection_matrix[:, 0, 0] = 1 / (aspect_wh * tan_half_fovy) + projection_matrix[:, 1, 1] = -1 / tan_half_fovy + projection_matrix[:, 2, 2] = -(far + near) / (far - near) + projection_matrix[:, 2, 3] = -2 * far * near / (far - near) + projection_matrix[:, 3, 2] = -1 + return projection_matrix + + +def get_orthogonal_projection_matrix( + batch_size: int, + left: float, + right: float, + bottom: float, + top: float, + near: float = 0.1, + far: float = 100.0, + device: Optional[str] = None, +) -> torch.FloatTensor: + projection_matrix = torch.zeros( + batch_size, 4, 4, dtype=torch.float32, device=device + ) + projection_matrix[:, 0, 0] = 2 / (right - left) + projection_matrix[:, 1, 1] = -2 / (top - bottom) + projection_matrix[:, 2, 2] = -2 / (far - near) + projection_matrix[:, 0, 3] = -(right + left) / (right - left) + projection_matrix[:, 1, 3] = -(top + bottom) / (top - bottom) + projection_matrix[:, 2, 3] = -(far + near) / (far - near) + projection_matrix[:, 3, 3] = 1 + return projection_matrix + + +@dataclass +class Camera: + c2w: Optional[torch.FloatTensor] + w2c: torch.FloatTensor + proj_mtx: torch.FloatTensor + mvp_mtx: torch.FloatTensor + cam_pos: Optional[torch.FloatTensor] + + def __getitem__(self, index): + if isinstance(index, int): + sl = slice(index, index + 1) + elif isinstance(index, slice): + sl = index + else: + raise NotImplementedError + + return Camera( + c2w=self.c2w[sl] if self.c2w is not None else None, + w2c=self.w2c[sl], + proj_mtx=self.proj_mtx[sl], + mvp_mtx=self.mvp_mtx[sl], + cam_pos=self.cam_pos[sl] if self.cam_pos is not None else None, + ) + + def to(self, device: Optional[str] = None): + if self.c2w is not None: + self.c2w = self.c2w.to(device) + self.w2c = self.w2c.to(device) + self.proj_mtx = self.proj_mtx.to(device) + self.mvp_mtx = self.mvp_mtx.to(device) + if self.cam_pos is not None: + self.cam_pos = self.cam_pos.to(device) + + def __len__(self): + return self.c2w.shape[0] + + +def get_camera( + elevation_deg: Optional[LIST_TYPE] = None, + distance: Optional[LIST_TYPE] = None, + fovy_deg: Optional[LIST_TYPE] = None, + azimuth_deg: Optional[LIST_TYPE] = None, + num_views: Optional[int] = 1, + c2w: Optional[torch.FloatTensor] = None, + w2c: Optional[torch.FloatTensor] = None, + proj_mtx: Optional[torch.FloatTensor] = None, + aspect_wh: float = 1.0, + near: float = 0.1, + far: float = 100.0, + device: Optional[str] = None, +): + if w2c is None: + if c2w is None: + c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device) + camera_positions = c2w[:, :3, 3] + w2c = torch.linalg.inv(c2w) + else: + camera_positions = None + c2w = None + if proj_mtx is None: + proj_mtx = get_projection_matrix( + fovy_deg, aspect_wh=aspect_wh, near=near, far=far, device=device + ) + mvp_mtx = proj_mtx @ w2c + return Camera( + c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions + ) + + +def get_orthogonal_camera( + elevation_deg: LIST_TYPE, + distance: LIST_TYPE, + left: float, + right: float, + bottom: float, + top: float, + azimuth_deg: Optional[LIST_TYPE] = None, + num_views: Optional[int] = 1, + near: float = 0.1, + far: float = 100.0, + device: Optional[str] = None, +): + c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device) + camera_positions = c2w[:, :3, 3] + w2c = torch.linalg.inv(c2w) + proj_mtx = get_orthogonal_projection_matrix( + batch_size=c2w.shape[0], + left=left, + right=right, + bottom=bottom, + top=top, + near=near, + far=far, + device=device, + ) + mvp_mtx = proj_mtx @ w2c + return Camera( + c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions + ) diff --git a/step1x3d_texture/utils/config.py b/step1x3d_texture/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..17688355d17dc30b45d85263688f75c782b9dcb8 --- /dev/null +++ b/step1x3d_texture/utils/config.py @@ -0,0 +1,140 @@ +import os +from dataclasses import dataclass, field +from datetime import datetime + +from omegaconf import OmegaConf + +from .core import debug, find, info, warn +from .typing import * + +# ============ Register OmegaConf Resolvers ============= # +OmegaConf.register_new_resolver( + "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) +) +OmegaConf.register_new_resolver("add", lambda a, b: a + b) +OmegaConf.register_new_resolver("sub", lambda a, b: a - b) +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) +OmegaConf.register_new_resolver("div", lambda a, b: a / b) +OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) +OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) +OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) +OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) +OmegaConf.register_new_resolver("gt0", lambda s: s > 0) +OmegaConf.register_new_resolver("not", lambda s: not s) + + +def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8): + return int(num_data / (num_nodes * num_cards * batch_size)) * max_epochs + + +OmegaConf.register_new_resolver("calc_num_train_steps", calc_num_train_steps) + +# ======================================================= # + + +# ============== Automatic Name Resolvers =============== # +def get_naming_convention(cfg): + # TODO + name = f"lrm_{cfg.system.backbone.num_layers}" + return name + + +# ======================================================= # + + +@dataclass +class ExperimentConfig: + name: str = "default" + description: str = "" + tag: str = "" + seed: int = 0 + use_timestamp: bool = True + timestamp: Optional[str] = None + exp_root_dir: str = "outputs" + + ### these shouldn't be set manually + exp_dir: str = "outputs/default" + trial_name: str = "exp" + trial_dir: str = "outputs/default/exp" + n_gpus: int = 1 + ### + + resume: Optional[str] = None + + data_cls: str = "" + data: dict = field(default_factory=dict) + + system_cls: str = "" + system: dict = field(default_factory=dict) + + # accept pytorch-lightning trainer parameters + # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api + trainer: dict = field(default_factory=dict) + + # accept pytorch-lightning checkpoint callback parameters + # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint + checkpoint: dict = field(default_factory=dict) + + +def load_config( + *yamls: str, cli_args: list = [], from_string=False, makedirs=True, **kwargs +) -> Any: + if from_string: + parse_func = OmegaConf.create + else: + parse_func = OmegaConf.load + yaml_confs = [] + for y in yamls: + conf = parse_func(y) + extends = conf.pop("extends", None) + if extends: + assert os.path.exists(extends), f"File {extends} does not exist." + yaml_confs.append(OmegaConf.load(extends)) + yaml_confs.append(conf) + cli_conf = OmegaConf.from_cli(cli_args) + cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) + OmegaConf.resolve(cfg) + assert isinstance(cfg, DictConfig) + scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg) + + # post processing + # auto naming + if scfg.name == "auto": + scfg.name = get_naming_convention(scfg) + # add timestamp + if not scfg.tag and not scfg.use_timestamp: + raise ValueError("Either tag is specified or use_timestamp is True.") + scfg.trial_name = scfg.tag + # if resume from an existing config, scfg.timestamp should not be None + if scfg.timestamp is None: + scfg.timestamp = "" + if scfg.use_timestamp: + if scfg.n_gpus > 1: + warn( + "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." + ) + else: + scfg.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") + # make directories + scfg.trial_name += scfg.timestamp + scfg.exp_dir = os.path.join(scfg.exp_root_dir, scfg.name) + scfg.trial_dir = os.path.join(scfg.exp_dir, scfg.trial_name) + + if makedirs: + os.makedirs(scfg.trial_dir, exist_ok=True) + + return scfg + + +def config_to_primitive(config, resolve: bool = True) -> Any: + return OmegaConf.to_container(config, resolve=resolve) + + +def dump_config(path: str, config) -> None: + with open(path, "w") as fp: + OmegaConf.save(config=config, f=fp) + + +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: + scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg) + return scfg diff --git a/step1x3d_texture/utils/core.py b/step1x3d_texture/utils/core.py new file mode 100644 index 0000000000000000000000000000000000000000..7c26c6de3f3f3b3ffee6d0c691342de9f92ac69f --- /dev/null +++ b/step1x3d_texture/utils/core.py @@ -0,0 +1,29 @@ +import importlib + +### grammar sugar for logging utilities ### +import logging + +logger = logging.getLogger("pytorch_lightning") + +from pytorch_lightning.utilities.rank_zero import ( + rank_zero_debug, + rank_zero_info, + rank_zero_only, +) + + +def find(cls_string): + module_string = ".".join(cls_string.split(".")[:-1]) + cls_name = cls_string.split(".")[-1] + module = importlib.import_module(module_string, package=None) + cls = getattr(module, cls_name) + return cls + + +debug = rank_zero_debug +info = rank_zero_info + + +@rank_zero_only +def warn(*args, **kwargs): + logger.warn(*args, **kwargs) diff --git a/step1x3d_texture/utils/geometry.py b/step1x3d_texture/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae9196ab36bc9a29b1b2a02a4c78b7e7257db48 --- /dev/null +++ b/step1x3d_texture/utils/geometry.py @@ -0,0 +1,253 @@ +from typing import List, Optional, Tuple + +import numpy as np +import torch +from torch.nn import functional as F + + +def get_position_map_from_depth(depth, mask, intrinsics, extrinsics, image_wh=None): + """Compute the position map from the depth map and the camera parameters for a batch of views. + + Args: + depth (torch.Tensor): The depth maps with the shape (B, H, W, 1). + mask (torch.Tensor): The masks with the shape (B, H, W, 1). + intrinsics (torch.Tensor): The camera intrinsics matrices with the shape (B, 3, 3). + extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4). + image_wh (Tuple[int, int]): The image width and height. + + Returns: + torch.Tensor: The position maps with the shape (B, H, W, 3). + """ + if image_wh is None: + image_wh = depth.shape[2], depth.shape[1] + + B, H, W, _ = depth.shape + depth = depth.squeeze(-1) + + u_coord, v_coord = torch.meshgrid( + torch.arange(image_wh[0]), torch.arange(image_wh[1]), indexing="xy" + ) + u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1) + v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1) + + # Compute the position map by back-projecting depth pixels to 3D space + x = ( + (u_coord - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)) + * depth + / intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1) + ) + y = ( + (v_coord - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)) + * depth + / intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1) + ) + z = depth + + # Concatenate to form the 3D coordinates in the camera frame + camera_coords = torch.stack([x, y, z], dim=-1) + + # Apply the extrinsic matrix to get coordinates in the world frame + coords_homogeneous = torch.nn.functional.pad( + camera_coords, (0, 1), "constant", 1.0 + ) # Add a homogeneous coordinate + world_coords = torch.matmul( + coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2) + ).view(B, H, W, 4) + + # Apply the mask to the position map + position_map = world_coords[..., :3] * mask + + return position_map + + +def get_position_map_from_depth_ortho( + depth, mask, extrinsics, ortho_scale, image_wh=None +): + """Compute the position map from the depth map and the camera parameters for a batch of views + using orthographic projection with a given ortho_scale. + + Args: + depth (torch.Tensor): The depth maps with the shape (B, H, W, 1). + mask (torch.Tensor): The masks with the shape (B, H, W, 1). + extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4). + ortho_scale (torch.Tensor): The scaling factor for the orthographic projection with the shape (B, 1, 1, 1). + image_wh (Tuple[int, int]): Optional. The image width and height. + + Returns: + torch.Tensor: The position maps with the shape (B, H, W, 3). + """ + if image_wh is None: + image_wh = depth.shape[2], depth.shape[1] + + B, H, W, _ = depth.shape + depth = depth.squeeze(-1) + + # Generating grid of coordinates in the image space + u_coord, v_coord = torch.meshgrid( + torch.arange(0, image_wh[0]), torch.arange(0, image_wh[1]), indexing="xy" + ) + u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1) + v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1) + + # Compute the position map using orthographic projection with ortho_scale + x = (u_coord - image_wh[0] / 2) * ortho_scale / image_wh[0] + y = (v_coord - image_wh[1] / 2) * ortho_scale / image_wh[1] + z = depth + + # Concatenate to form the 3D coordinates in the camera frame + camera_coords = torch.stack([x, y, z], dim=-1) + + # Apply the extrinsic matrix to get coordinates in the world frame + coords_homogeneous = torch.nn.functional.pad( + camera_coords, (0, 1), "constant", 1.0 + ) # Add a homogeneous coordinate + world_coords = torch.matmul( + coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2) + ).view(B, H, W, 4) + + # Apply the mask to the position map + position_map = world_coords[..., :3] * mask + + return position_map + + +def get_opencv_from_blender(matrix_world, fov=None, image_size=None): + # convert matrix_world to opencv format extrinsics + opencv_world_to_cam = matrix_world.inverse() + opencv_world_to_cam[1, :] *= -1 + opencv_world_to_cam[2, :] *= -1 + R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3] + + if fov is None: # orthographic camera + return R, T + + R, T = R.unsqueeze(0), T.unsqueeze(0) + # convert fov to opencv format intrinsics + focal = 1 / np.tan(fov / 2) + intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32) + opencv_cam_matrix = ( + torch.from_numpy(intrinsics).unsqueeze(0).float().to(matrix_world.device) + ) + opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2]).to( + matrix_world.device + ) + opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2 + + return R, T, opencv_cam_matrix + + +def get_ray_directions( + H: int, + W: int, + focal: float, + principal: Optional[Tuple[float, float]] = None, + use_pixel_centers: bool = True, +) -> torch.Tensor: + """ + Get ray directions for all pixels in camera coordinate. + Args: + H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + pixel_center = 0.5 if use_pixel_centers else 0 + cx, cy = W / 2, H / 2 if principal is None else principal + i, j = torch.meshgrid( + torch.arange(W, dtype=torch.float32) + pixel_center, + torch.arange(H, dtype=torch.float32) + pixel_center, + indexing="xy", + ) + directions = torch.stack( + [(i - cx) / focal, -(j - cy) / focal, -torch.ones_like(i)], -1 + ) + return F.normalize(directions, dim=-1) + + +def get_rays( + directions: torch.Tensor, c2w: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get ray origins and directions from camera coordinates to world coordinates + Args: + directions: (H, W, 3) ray directions in camera coordinates + c2w: (4, 4) camera-to-world transformation matrix + Outputs: + rays_o, rays_d: (H, W, 3) ray origins and directions in world coordinates + """ + # Rotate ray directions from camera coordinate to the world coordinate + rays_d = directions @ c2w[:3, :3].T + rays_o = c2w[:3, 3].expand(rays_d.shape) + return rays_o, rays_d + + +def compute_plucker_embed( + c2w: torch.Tensor, image_width: int, image_height: int, focal: float +) -> torch.Tensor: + """ + Computes Plucker coordinates for a camera. + Args: + c2w: (4, 4) camera-to-world transformation matrix + image_width: Image width + image_height: Image height + focal: Focal length of the camera + Returns: + plucker: (6, H, W) Plucker embedding + """ + directions = get_ray_directions(image_height, image_width, focal) + rays_o, rays_d = get_rays(directions, c2w) + # Cross product to get Plucker coordinates + cross = torch.cross(rays_o, rays_d, dim=-1) + plucker = torch.cat((rays_d, cross), dim=-1) + return plucker.permute(2, 0, 1) + + +def get_plucker_embeds_from_cameras( + c2w: List[torch.Tensor], fov: List[float], image_size: int +) -> torch.Tensor: + """ + Given lists of camera transformations and fov, returns the batched plucker embeddings. + Args: + c2w: list of camera-to-world transformation matrices + fov: list of field of view values + image_size: size of the image + Returns: + plucker_embeds: (B, 6, H, W) batched plucker embeddings + """ + plucker_embeds = [] + for cam_matrix, cam_fov in zip(c2w, fov): + focal = 0.5 * image_size / np.tan(0.5 * cam_fov) + plucker = compute_plucker_embed(cam_matrix, image_size, image_size, focal) + plucker_embeds.append(plucker) + return torch.stack(plucker_embeds) + + +def get_plucker_embeds_from_cameras_ortho( + c2w: List[torch.Tensor], ortho_scale: List[float], image_size: int +): + """ + Given lists of camera transformations and fov, returns the batched plucker embeddings. + + Parameters: + c2w: list of camera-to-world transformation matrices + fov: list of field of view values + image_size: size of the image + + Returns: + plucker_embeds: plucker embeddings (B, 6, H, W) + """ + plucker_embeds = [] + # compute pairwise mask and plucker embeddings + for cam_matrix, scale in zip(c2w, ortho_scale): + # blender to opencv to pytorch3d + R, T = get_opencv_from_blender(cam_matrix) + cam_pos = -R.T @ T + view_dir = R.T @ torch.tensor([0, 0, 1]).float().to(cam_matrix.device) + # normalize camera position + cam_pos = F.normalize(cam_pos, dim=0) + plucker = torch.concat([view_dir, cam_pos]) + plucker = plucker.unsqueeze(-1).unsqueeze(-1).repeat(1, image_size, image_size) + plucker_embeds.append(plucker) + + plucker_embeds = torch.stack(plucker_embeds) + + return plucker_embeds diff --git a/step1x3d_texture/utils/logging.py b/step1x3d_texture/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..99fc47d2629d7a524d17c9ca28803e4a2abd8a83 --- /dev/null +++ b/step1x3d_texture/utils/logging.py @@ -0,0 +1,340 @@ +# coding=utf-8 +# Copyright 2024 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import CRITICAL # NOQA +from logging import DEBUG # NOQA +from logging import ERROR # NOQA +from logging import FATAL # NOQA +from logging import INFO # NOQA +from logging import NOTSET # NOQA +from logging import WARN # NOQA +from logging import WARNING # NOQA +from typing import Dict, Optional + +from tqdm import auto as tqdm_lib + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.INFO + +_tqdm_active = True + + +def _get_default_logging_level() -> int: + """ + If LATEXTURE_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("LATEXTURE_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option LATEXTURE_VERBOSITY={env_level_str}, " + f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + enable_explicit_format() + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict() -> Dict[str, int]: + return log_levels + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom diffusers module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the 🤗 Diffusers' root logger as an `int`. + + Returns: + `int`: + Logging level integers which can be one of: + + - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `40`: `diffusers.logging.ERROR` + - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `20`: `diffusers.logging.INFO` + - `10`: `diffusers.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the 🤗 Diffusers' root logger. + + Args: + verbosity (`int`): + Logging level which can be one of: + + - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `diffusers.logging.ERROR` + - `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `diffusers.logging.INFO` + - `diffusers.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info() -> None: + """Set the verbosity to the `INFO` level.""" + return set_verbosity(INFO) + + +def set_verbosity_warning() -> None: + """Set the verbosity to the `WARNING` level.""" + return set_verbosity(WARNING) + + +def set_verbosity_debug() -> None: + """Set the verbosity to the `DEBUG` level.""" + return set_verbosity(DEBUG) + + +def set_verbosity_error() -> None: + """Set the verbosity to the `ERROR` level.""" + return set_verbosity(ERROR) + + +def disable_default_handler() -> None: + """Disable the default handler of the 🤗 Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the 🤗 Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter( + "[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s" + ) + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for 🤗 Diffusers' loggers. + + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs) -> None: + """ + This method is identical to `logger.warning()`, but if env var LATEXTURE_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("LATEXTURE_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar() -> None: + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar() -> None: + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/step1x3d_texture/utils/misc.py b/step1x3d_texture/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e29a9b7865641628576d39c469ec7396bc6d9f68 --- /dev/null +++ b/step1x3d_texture/utils/misc.py @@ -0,0 +1,220 @@ +import gc +import os +import re +import time +from collections import defaultdict +from contextlib import contextmanager + +import psutil +import torch +from packaging import version + +from .config import config_to_primitive +from .core import debug, find, info, warn +from .typing import * + + +def parse_version(ver: str): + return version.parse(ver) + + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + + +def get_device(): + return torch.device(f"cuda:{get_rank()}") + + +def load_module_weights( + path, module_name=None, ignore_modules=None, mapping=None, map_location=None +) -> Tuple[dict, int, int]: + if module_name is not None and ignore_modules is not None: + raise ValueError("module_name and ignore_modules cannot be both set") + if map_location is None: + map_location = get_device() + + ckpt = torch.load(path, map_location=map_location) + state_dict = ckpt["state_dict"] + + if mapping is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + if any([k.startswith(m["to"]) for m in mapping]): + pass + else: + state_dict_to_load[k] = v + for k, v in state_dict.items(): + for m in mapping: + if k.startswith(m["from"]): + k_dest = k.replace(m["from"], m["to"]) + info(f"Mapping {k} => {k_dest}") + state_dict_to_load[k_dest] = v.clone() + state_dict = state_dict_to_load + + state_dict_to_load = state_dict + + if ignore_modules is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + ignore = any( + [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] + ) + if ignore: + continue + state_dict_to_load[k] = v + + if module_name is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + m = re.match(rf"^{module_name}\.(.*)$", k) + if m is None: + continue + state_dict_to_load[m.group(1)] = v + + return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] + + +def C(value: Any, epoch: int, global_step: int) -> float: + if isinstance(value, int) or isinstance(value, float): + pass + else: + value = config_to_primitive(value) + if not isinstance(value, list): + raise TypeError("Scalar specification only supports list, got", type(value)) + if len(value) == 3: + value = [0] + value + assert len(value) == 4 + start_step, start_value, end_value, end_step = value + if isinstance(end_step, int): + current_step = global_step + value = start_value + (end_value - start_value) * max( + min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 + ) + elif isinstance(end_step, float): + current_step = epoch + value = start_value + (end_value - start_value) * max( + min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 + ) + return value + + +def cleanup(): + gc.collect() + torch.cuda.empty_cache() + try: + import tinycudann as tcnn + + tcnn.free_temporary_memory() + except: + pass + + +def finish_with_cleanup(func: Callable): + def wrapper(*args, **kwargs): + out = func(*args, **kwargs) + cleanup() + return out + + return wrapper + + +def _distributed_available(): + return torch.distributed.is_available() and torch.distributed.is_initialized() + + +def barrier(): + if not _distributed_available(): + return + else: + torch.distributed.barrier() + + +def broadcast(tensor, src=0): + if not _distributed_available(): + return tensor + else: + torch.distributed.broadcast(tensor, src=src) + return tensor + + +def enable_gradient(model, enabled: bool = True) -> None: + for param in model.parameters(): + param.requires_grad_(enabled) + + +class TimeRecorder: + _instance = None + + def __init__(self): + self.items = {} + self.accumulations = defaultdict(list) + self.time_scale = 1000.0 # ms + self.time_unit = "ms" + self.enabled = False + + def __new__(cls): + # singleton + if cls._instance is None: + cls._instance = super(TimeRecorder, cls).__new__(cls) + return cls._instance + + def enable(self, enabled: bool) -> None: + self.enabled = enabled + + def start(self, name: str) -> None: + if not self.enabled: + return + torch.cuda.synchronize() + self.items[name] = time.time() + + def end(self, name: str, accumulate: bool = False) -> float: + if not self.enabled or name not in self.items: + return + torch.cuda.synchronize() + start_time = self.items.pop(name) + delta = time.time() - start_time + if accumulate: + self.accumulations[name].append(delta) + t = delta * self.time_scale + info(f"{name}: {t:.2f}{self.time_unit}") + + def get_accumulation(self, name: str, average: bool = False) -> float: + if not self.enabled or name not in self.accumulations: + return + acc = self.accumulations.pop(name) + total = sum(acc) + if average: + t = total / len(acc) * self.time_scale + else: + t = total * self.time_scale + info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}") + + +### global time recorder +time_recorder = TimeRecorder() + + +@contextmanager +def time_recorder_enabled(): + enabled = time_recorder.enabled + time_recorder.enable(enabled=True) + try: + yield + finally: + time_recorder.enable(enabled=enabled) + + +def show_vram_usage(name): + available, total = torch.cuda.mem_get_info() + used = total - available + print( + f"{name}: {used / 1024**2:.1f}MB, {psutil.Process(os.getpid()).memory_info().rss / 1024**2:.1f}MB" + ) diff --git a/step1x3d_texture/utils/ops.py b/step1x3d_texture/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8272799b06ebcc1231034908e085a04a123732d0 --- /dev/null +++ b/step1x3d_texture/utils/ops.py @@ -0,0 +1,462 @@ +import math +from collections import defaultdict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +from .core import debug, find, info, warn +from .typing import * + + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + + +def reflect(x, n): + return 2 * dot(x, n) * n - x + + +ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] + + +def scale_tensor( + dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale +): + if inp_scale is None: + inp_scale = (0, 1) + if tgt_scale is None: + tgt_scale = (0, 1) + if isinstance(tgt_scale, Tensor): + assert dat.shape[-1] == tgt_scale.shape[-1] + dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) + dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] + return dat + + +class _TruncExp(Function): # pylint: disable=abstract-method + # Implementation from torch-ngp: + # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, x): # pylint: disable=arguments-differ + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): # pylint: disable=arguments-differ + x = ctx.saved_tensors[0] + return g * torch.exp(torch.clamp(x, max=15)) + + +trunc_exp = _TruncExp.apply + + +def get_activation(name) -> Callable: + if name is None: + return lambda x: x + name = name.lower() + if name == "none": + return lambda x: x + elif name == "lin2srgb": + return lambda x: torch.where( + x > 0.0031308, + torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, + 12.92 * x, + ).clamp(0.0, 1.0) + elif name == "exp": + return lambda x: torch.exp(x) + elif name == "shifted_exp": + return lambda x: torch.exp(x - 1.0) + elif name == "trunc_exp": + return trunc_exp + elif name == "shifted_trunc_exp": + return lambda x: trunc_exp(x - 1.0) + elif name == "sigmoid": + return lambda x: torch.sigmoid(x) + elif name == "tanh": + return lambda x: torch.tanh(x) + elif name == "shifted_softplus": + return lambda x: F.softplus(x - 1.0) + elif name == "scale_-11_01": + return lambda x: x * 0.5 + 0.5 + elif name == "negative": + return lambda x: -x + else: + try: + return getattr(F, name) + except AttributeError: + raise ValueError(f"Unknown activation function: {name}") + + +def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any: + if chunk_size <= 0: + return func(*args, **kwargs) + B = None + for arg in list(args) + list(kwargs.values()): + if isinstance(arg, torch.Tensor): + B = arg.shape[0] + break + assert ( + B is not None + ), "No tensor found in args or kwargs, cannot determine batch size." + out = defaultdict(list) + out_type = None + # max(1, B) to support B == 0 + for i in range(0, max(1, B), chunk_size): + out_chunk = func( + *[ + arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for arg in args + ], + **{ + k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for k, arg in kwargs.items() + }, + ) + if out_chunk is None: + continue + out_type = type(out_chunk) + if isinstance(out_chunk, torch.Tensor): + out_chunk = {0: out_chunk} + elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): + chunk_length = len(out_chunk) + out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} + elif isinstance(out_chunk, dict): + pass + else: + print( + f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}." + ) + exit(1) + for k, v in out_chunk.items(): + v = v if torch.is_grad_enabled() else v.detach() + out[k].append(v) + + if out_type is None: + return None + + out_merged: Dict[Any, Optional[torch.Tensor]] = {} + for k, v in out.items(): + if all([vv is None for vv in v]): + # allow None in return value + out_merged[k] = None + elif all([isinstance(vv, torch.Tensor) for vv in v]): + out_merged[k] = torch.cat(v, dim=0) + else: + raise TypeError( + f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}" + ) + + if out_type is torch.Tensor: + return out_merged[0] + elif out_type in [tuple, list]: + return out_type([out_merged[i] for i in range(chunk_length)]) + elif out_type is dict: + return out_merged + + +def get_ray_directions( + H: int, + W: int, + focal: Union[float, Tuple[float, float]], + principal: Optional[Tuple[float, float]] = None, + use_pixel_centers: bool = True, + normalize: bool = True, +) -> Float[Tensor, "H W 3"]: + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + + Inputs: + H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + pixel_center = 0.5 if use_pixel_centers else 0 + + if isinstance(focal, float): + fx, fy = focal, focal + cx, cy = W / 2, H / 2 + else: + fx, fy = focal + assert principal is not None + cx, cy = principal + + i, j = torch.meshgrid( + torch.arange(W, dtype=torch.float32) + pixel_center, + torch.arange(H, dtype=torch.float32) + pixel_center, + indexing="xy", + ) + + directions: Float[Tensor, "H W 3"] = torch.stack( + [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1 + ) + + if normalize: + directions = F.normalize(directions, dim=-1) + + return directions + + +def get_rays( + directions: Float[Tensor, "... 3"], + c2w: Float[Tensor, "... 4 4"], + keepdim=False, + noise_scale=0.0, + normalize=False, +) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 3"]]: + # Rotate ray directions from camera coordinate to the world coordinate + assert directions.shape[-1] == 3 + + if directions.ndim == 2: # (N_rays, 3) + if c2w.ndim == 2: # (4, 4) + c2w = c2w[None, :, :] + assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4) + rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3) + rays_o = c2w[:, :3, 3].expand(rays_d.shape) + elif directions.ndim == 3: # (H, W, 3) + assert c2w.ndim in [2, 3] + if c2w.ndim == 2: # (4, 4) + rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum( + -1 + ) # (H, W, 3) + rays_o = c2w[None, None, :3, 3].expand(rays_d.shape) + elif c2w.ndim == 3: # (B, 4, 4) + rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + elif directions.ndim == 4: # (B, H, W, 3) + assert c2w.ndim == 3 # (B, 4, 4) + rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + + # add camera noise to avoid grid-like artifect + # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373 + if noise_scale > 0: + rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale + rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale + + if normalize: + rays_d = F.normalize(rays_d, dim=-1) + if not keepdim: + rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) + + return rays_o, rays_d + + +def get_projection_matrix( + fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float +) -> Float[Tensor, "*B 4 4"]: + if isinstance(fovy, float): + proj_mtx = torch.zeros(4, 4, dtype=torch.float32) + proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh) + proj_mtx[1, 1] = -1.0 / math.tan( + fovy / 2.0 + ) # add a negative sign here as the y axis is flipped in nvdiffrast output + proj_mtx[2, 2] = -(far + near) / (far - near) + proj_mtx[2, 3] = -2.0 * far * near / (far - near) + proj_mtx[3, 2] = -1.0 + else: + batch_size = fovy.shape[0] + proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32) + proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh) + proj_mtx[:, 1, 1] = -1.0 / torch.tan( + fovy / 2.0 + ) # add a negative sign here as the y axis is flipped in nvdiffrast output + proj_mtx[:, 2, 2] = -(far + near) / (far - near) + proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near) + proj_mtx[:, 3, 2] = -1.0 + return proj_mtx + + +def get_mvp_matrix( + c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"] +) -> Float[Tensor, "*B 4 4"]: + # calculate w2c from c2w: R' = Rt, t' = -Rt * t + # mathematically equivalent to (c2w)^-1 + if c2w.ndim == 2: + assert proj_mtx.ndim == 2 + w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w) + w2c[:3, :3] = c2w[:3, :3].permute(1, 0) + w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:] + w2c[3, 3] = 1.0 + else: + w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w) + w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1) + w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:] + w2c[:, 3, 3] = 1.0 + # calculate mvp matrix by proj_mtx @ w2c (mv_mtx) + mvp_mtx = proj_mtx @ w2c + return mvp_mtx + + +def get_intrinsic_from_fov(fov, H, W, bs=-1): + focal_length = 0.5 * H / np.tan(0.5 * fov) + intrinsic = np.identity(3, dtype=np.float32) + intrinsic[0, 0] = focal_length + intrinsic[1, 1] = focal_length + intrinsic[0, 2] = W / 2.0 + intrinsic[1, 2] = H / 2.0 + + if bs > 0: + intrinsic = intrinsic[None].repeat(bs, axis=0) + + return torch.from_numpy(intrinsic) + + +def binary_cross_entropy(input, target): + """ + F.binary_cross_entropy is not numerically stable in mixed-precision training. + """ + return -(target * torch.log(input) + (1 - target) * torch.log(1 - input)).mean() + + +def tet_sdf_diff( + vert_sdf: Float[Tensor, "Nv 1"], tet_edges: Integer[Tensor, "Ne 2"] +) -> Float[Tensor, ""]: + sdf_f1x6x2 = vert_sdf[:, 0][tet_edges.reshape(-1)].reshape(-1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float() + ) + F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float() + ) + return sdf_diff + + +def validate_empty_rays(ray_indices, t_start, t_end): + if ray_indices.nelement() == 0: + warn("Empty rays_indices!") + ray_indices = torch.LongTensor([0]).to(ray_indices) + t_start = torch.Tensor([0]).to(ray_indices) + t_end = torch.Tensor([0]).to(ray_indices) + return ray_indices, t_start, t_end + + +def rays_intersect_bbox( + rays_o: Float[Tensor, "N 3"], + rays_d: Float[Tensor, "N 3"], + radius: Float, + near: Float = 0.0, + valid_thresh: Float = 0.01, + background: bool = False, +): + input_shape = rays_o.shape[:-1] + rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3) + rays_d_valid = torch.where( + rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d + ) + if type(radius) in [int, float]: + radius = torch.FloatTensor( + [[-radius, radius], [-radius, radius], [-radius, radius]] + ).to(rays_o.device) + radius = ( + 1.0 - 1.0e-3 + ) * radius # tighten the radius to make sure the intersection point lies in the bounding box + interx0 = (radius[..., 1] - rays_o) / rays_d_valid + interx1 = (radius[..., 0] - rays_o) / rays_d_valid + t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near) + t_far = torch.maximum(interx0, interx1).amin(dim=-1) + + # check wheter a ray intersects the bbox or not + rays_valid = t_far - t_near > valid_thresh + + t_near_valid, t_far_valid = t_near[rays_valid], t_far[rays_valid] + global_near = t_near_valid.min().item() + global_far = t_far_valid.max().item() + + t_near[torch.where(~rays_valid)] = 0.0 + t_far[torch.where(~rays_valid)] = 0.0 + + t_near = t_near.view(*input_shape, 1) + t_far = t_far.view(*input_shape, 1) + rays_valid = rays_valid.view(*input_shape) + + return t_near, t_far, rays_valid + + +def get_plucker_rays( + rays_o: Float[Tensor, "*N 3"], rays_d: Float[Tensor, "*N 3"] +) -> Float[Tensor, "*N 6"]: + rays_o = F.normalize(rays_o, dim=-1) + rays_d = F.normalize(rays_d, dim=-1) + return torch.cat([rays_o.cross(rays_d), rays_d], dim=-1) + + +def c2w_to_polar(c2w: Float[Tensor, "4 4"]) -> Tuple[float, float, float]: + cam_pos = c2w[:3, 3] + x, y, z = cam_pos.tolist() + distance = cam_pos.norm().item() + elevation = math.asin(z / distance) + if abs(x) < 1.0e-5 and abs(y) < 1.0e-5: + azimuth = 0 + else: + azimuth = math.atan2(y, x) + if azimuth < 0: + azimuth += 2 * math.pi + + return elevation, azimuth, distance + + +def polar_to_c2w( + elevation: float, azimuth: float, distance: float +) -> Float[Tensor, "4 4"]: + """ + Compute L = p - C. + Normalize L. + Compute s = L x u. (cross product) + Normalize s. + Compute u' = s x L. + rotation = [s, u, -l] + """ + z = distance * math.sin(elevation) + x = distance * math.cos(elevation) * math.cos(azimuth) + y = distance * math.cos(elevation) * math.sin(azimuth) + l = -torch.as_tensor([x, y, z]).float() + l = F.normalize(l, dim=0) + u = torch.as_tensor([0.0, 0.0, 1.0]).float() + s = l.cross(u) + s = F.normalize(s, dim=0) + u = s.cross(l) + rot = torch.stack([s, u, -l], dim=0).T + c2w = torch.zeros((4, 4), dtype=torch.float32) + c2w[:3, :3] = rot + c2w[:3, 3] = torch.as_tensor([x, y, z]) + c2w[3, 3] = 1 + return c2w + + +def fourier_position_encoding(x, n_freq: int, dim: int): + assert n_freq > 0 + input_shape = x.shape + input_ndim = x.ndim + if dim < 0: + dim = input_ndim + dim + bands = 2 ** torch.arange(n_freq, dtype=x.dtype, device=x.device) + for i in range(dim + 1): + bands = bands.unsqueeze(0) + for i in range(input_ndim - dim - 1): + bands = bands.unsqueeze(-1) + x = x.view(*input_shape[: dim + 1], 1, *input_shape[dim + 1 :]) + x = torch.cat( + [ + torch.sin(bands * x).reshape( + *input_shape[:dim], -1, *input_shape[dim + 1 :] + ), + torch.cos(bands * x).reshape( + *input_shape[:dim], -1, *input_shape[dim + 1 :] + ), + ], + dim=dim, + ) + return x diff --git a/step1x3d_texture/utils/render.py b/step1x3d_texture/utils/render.py new file mode 100644 index 0000000000000000000000000000000000000000..e01b8226804614550b1a20c7c5699efd243029d4 --- /dev/null +++ b/step1x3d_texture/utils/render.py @@ -0,0 +1,520 @@ +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime +from typing import List, Optional, Union + +import numpy as np +import nvdiffrast.torch as dr +import torch +import torch.nn.functional as F +import trimesh +from PIL import Image +from torch import BoolTensor, FloatTensor + +from . import logging +from .camera import Camera +import xatlas + +logger = logging.get_logger(__name__) + + +def dot(x: torch.FloatTensor, y: torch.FloatTensor) -> torch.FloatTensor: + return torch.sum(x * y, -1, keepdim=True) + + +@dataclass +class TexturedMesh: + v_pos: torch.FloatTensor + t_pos_idx: torch.LongTensor + + # texture coordinates + v_tex: Optional[torch.FloatTensor] = None + t_tex_idx: Optional[torch.LongTensor] = None + + # texture map + texture: Optional[torch.FloatTensor] = None + + # vertices, faces after vertex merging + _stitched_v_pos: Optional[torch.FloatTensor] = None + _stitched_t_pos_idx: Optional[torch.LongTensor] = None + + _v_nrm: Optional[torch.FloatTensor] = None + + @property + def v_nrm(self) -> torch.FloatTensor: + if self._v_nrm is None: + self._v_nrm = self._compute_vertex_normal() + return self._v_nrm + + def set_stitched_mesh( + self, v_pos: torch.FloatTensor, t_pos_idx: torch.LongTensor + ) -> None: + self._stitched_v_pos = v_pos + self._stitched_t_pos_idx = t_pos_idx + + @property + def stitched_v_pos(self) -> torch.FloatTensor: + if self._stitched_v_pos is None: + logger.warning("Stitched vertices not available, using original vertices!") + return self.v_pos + return self._stitched_v_pos + + @property + def stitched_t_pos_idx(self) -> torch.LongTensor: + if self._stitched_t_pos_idx is None: + logger.warning("Stitched faces not available, using original faces!") + return self.t_pos_idx + return self._stitched_t_pos_idx + + def _compute_vertex_normal(self) -> torch.FloatTensor: + if self._stitched_v_pos is None or self._stitched_t_pos_idx is None: + logger.warning( + "Stitched vertices and faces not available, computing vertex normals on original mesh, which can be erroneous!" + ) + v_pos, t_pos_idx = self.v_pos, self.t_pos_idx + else: + v_pos, t_pos_idx = self._stitched_v_pos, self._stitched_t_pos_idx + + i0 = t_pos_idx[:, 0] + i1 = t_pos_idx[:, 1] + i2 = t_pos_idx[:, 2] + + v0 = v_pos[i0, :] + v1 = v_pos[i1, :] + v2 = v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where( + dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) + ) + v_nrm = F.normalize(v_nrm, dim=1) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return v_nrm + + def to(self, device: Optional[str] = None): + self.v_pos = self.v_pos.to(device) + self.t_pos_idx = self.t_pos_idx.to(device) + if self.v_tex is not None: + self.v_tex = self.v_tex.to(device) + if self.t_tex_idx is not None: + self.t_tex_idx = self.t_tex_idx.to(device) + if self.texture is not None: + self.texture = self.texture.to(device) + if self._stitched_v_pos is not None: + self._stitched_v_pos = self._stitched_v_pos.to(device) + if self._stitched_t_pos_idx is not None: + self._stitched_t_pos_idx = self._stitched_t_pos_idx.to(device) + if self._v_nrm is not None: + self._v_nrm = self._v_nrm.to(device) + + +def mesh_uv_wrap(mesh): + if isinstance(mesh, trimesh.Scene): + mesh = mesh.dump(concatenate=True) + + if len(mesh.faces) > 500000000: + raise ValueError( + "The mesh has more than 500,000,000 faces, which is not supported." + ) + + vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces) + + mesh.vertices = mesh.vertices[vmapping] + mesh.faces = indices + mesh.visual.uv = uvs + + return mesh + + +def load_mesh( + mesh: str, + rescale: bool = False, + move_to_center: bool = False, + scale: float = 0.5, + flip_uv: bool = True, + merge_vertices: bool = True, + default_uv_size: int = 2048, + shape_init_mesh_up: str = "+y", + shape_init_mesh_front: str = "+x", + front_x_to_y: bool = False, + device: Optional[str] = None, + return_transform: bool = False, +) -> TexturedMesh: + scene = mesh + # scene = trimesh.load(mesh, force="mesh", process=False) + if isinstance(scene, trimesh.Trimesh): + mesh = scene + elif isinstance(scene, trimesh.scene.Scene): + mesh = trimesh.Trimesh() + for obj in scene.geometry.values(): + mesh = trimesh.util.concatenate([mesh, obj]) + else: + raise ValueError(f"Unknown mesh type at {mesh_path}.") + + # move to center + if move_to_center: + centroid = mesh.vertices.mean(0) + mesh.vertices = mesh.vertices - centroid + mesh = mesh_uv_wrap(mesh) + # rescale + if rescale: + max_scale = np.abs(mesh.vertices).max() + mesh.vertices = mesh.vertices / max_scale * scale + + mesh_bp = trimesh.base.Trimesh.copy(mesh) + dirs = ["+x", "+y", "+z", "-x", "-y", "-z"] + dir2vec = { + "+x": np.array([1, 0, 0]), + "+y": np.array([0, 1, 0]), + "+z": np.array([0, 0, 1]), + "-x": np.array([-1, 0, 0]), + "-y": np.array([0, -1, 0]), + "-z": np.array([0, 0, -1]), + } + if shape_init_mesh_up not in dirs or shape_init_mesh_front not in dirs: + raise ValueError( + f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}." + ) + if shape_init_mesh_up[1] == shape_init_mesh_front[1]: + raise ValueError( + "shape_init_mesh_up and shape_init_mesh_front must be orthogonal." + ) + z_, x_ = ( + dir2vec[shape_init_mesh_up], + dir2vec[shape_init_mesh_front], + ) + y_ = np.cross(z_, x_) + std2mesh = np.stack([x_, y_, z_], axis=0).T + mesh2std = np.linalg.inv(std2mesh) + mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T + if front_x_to_y: + x = mesh.vertices[:, 1].copy() + y = -mesh.vertices[:, 0].copy() + mesh.vertices[:, 0] = x + mesh.vertices[:, 1] = y + + v_pos = torch.tensor(mesh.vertices, dtype=torch.float32) + t_pos_idx = torch.tensor(mesh.faces, dtype=torch.int64) + + if hasattr(mesh, "visual") and hasattr(mesh.visual, "uv"): + v_tex = torch.tensor(mesh.visual.uv, dtype=torch.float32) + if flip_uv: + v_tex[:, 1] = 1.0 - v_tex[:, 1] + t_tex_idx = t_pos_idx.clone() + # if ( + # hasattr(mesh.visual.material, "baseColorTexture") + # and mesh.visual.material.baseColorTexture + # ): + # texture = torch.tensor( + # np.array(mesh.visual.material.baseColorTexture) / 255.0, + # dtype=torch.float32, + # )[..., :3] + # else: + texture = torch.zeros( + (default_uv_size, default_uv_size, 3), dtype=torch.float32 + ) + else: + v_tex = None + t_tex_idx = None + texture = None + + textured_mesh = TexturedMesh( + v_pos=v_pos, + t_pos_idx=t_pos_idx, + v_tex=v_tex, + t_tex_idx=t_tex_idx, + texture=texture, + ) + + if merge_vertices: + mesh.merge_vertices(merge_tex=True) + textured_mesh.set_stitched_mesh( + torch.tensor(mesh.vertices, dtype=torch.float32), + torch.tensor(mesh.faces, dtype=torch.int64), + ) + + textured_mesh.to(device) + + if return_transform: + return textured_mesh, np.array(centroid), max_scale / scale + + return textured_mesh, mesh_bp + + +@dataclass +class RenderOutput: + attr: Optional[torch.FloatTensor] = None + mask: Optional[torch.BoolTensor] = None + depth: Optional[torch.FloatTensor] = None + normal: Optional[torch.FloatTensor] = None + pos: Optional[torch.FloatTensor] = None + + +class NVDiffRastContextWrapper: + def __init__(self, device: str, context_type: str = "gl"): + if context_type == "gl": + self.ctx = dr.RasterizeGLContext(device=device) + elif context_type == "cuda": + self.ctx = dr.RasterizeCudaContext(device=device) + else: + raise NotImplementedError + + def rasterize(self, pos, tri, resolution, ranges=None, grad_db=True): + """ + Rasterize triangles. + + All input tensors must be contiguous and reside in GPU memory except for the ranges tensor that, if specified, has to reside in CPU memory. The output tensors will be contiguous and reside in GPU memory. + + Arguments: + glctx Rasterizer context of type RasterizeGLContext or RasterizeCudaContext. + pos Vertex position tensor with dtype torch.float32. To enable range mode, this tensor should have a 2D shape [num_vertices, 4]. To enable instanced mode, use a 3D shape [minibatch_size, num_vertices, 4]. + tri Triangle tensor with shape [num_triangles, 3] and dtype torch.int32. + resolution Output resolution as integer tuple (height, width). + ranges In range mode, tensor with shape [minibatch_size, 2] and dtype torch.int32, specifying start indices and counts into tri. Ignored in instanced mode. + grad_db Propagate gradients of image-space derivatives of barycentrics into pos in backward pass. Ignored if using an OpenGL context that was not configured to output image-space derivatives. + Returns: + A tuple of two tensors. The first output tensor has shape [minibatch_size, height, width, 4] and contains the main rasterizer output in order (u, v, z/w, triangle_id). If the OpenGL context was configured to output image-space derivatives of barycentrics, the second output tensor will also have shape [minibatch_size, height, width, 4] and contain said derivatives in order (du/dX, du/dY, dv/dX, dv/dY). Otherwise it will be an empty tensor with shape [minibatch_size, height, width, 0]. + """ + return dr.rasterize( + self.ctx, pos.float(), tri.int(), resolution, ranges, grad_db + ) + + def interpolate(self, attr, rast, tri, rast_db=None, diff_attrs=None): + """ + Interpolate vertex attributes. + + All input tensors must be contiguous and reside in GPU memory. The output tensors will be contiguous and reside in GPU memory. + + Arguments: + attr Attribute tensor with dtype torch.float32. Shape is [num_vertices, num_attributes] in range mode, or [minibatch_size, num_vertices, num_attributes] in instanced mode. Broadcasting is supported along the minibatch axis. + rast Main output tensor from rasterize(). + tri Triangle tensor with shape [num_triangles, 3] and dtype torch.int32. + rast_db (Optional) Tensor containing image-space derivatives of barycentrics, i.e., the second output tensor from rasterize(). Enables computing image-space derivatives of attributes. + diff_attrs (Optional) List of attribute indices for which image-space derivatives are to be computed. Special value 'all' is equivalent to list [0, 1, ..., num_attributes - 1]. + Returns: + A tuple of two tensors. The first output tensor contains interpolated attributes and has shape [minibatch_size, height, width, num_attributes]. If rast_db and diff_attrs were specified, the second output tensor contains the image-space derivatives of the selected attributes and has shape [minibatch_size, height, width, 2 * len(diff_attrs)]. The derivatives of the first selected attribute A will be on channels 0 and 1 as (dA/dX, dA/dY), etc. Otherwise, the second output tensor will be an empty tensor with shape [minibatch_size, height, width, 0]. + """ + return dr.interpolate(attr.float(), rast, tri.int(), rast_db, diff_attrs) + + def texture( + self, + tex, + uv, + uv_da=None, + mip_level_bias=None, + mip=None, + filter_mode="auto", + boundary_mode="wrap", + max_mip_level=None, + ): + """ + Perform texture sampling. + + All input tensors must be contiguous and reside in GPU memory. The output tensor will be contiguous and reside in GPU memory. + + Arguments: + tex Texture tensor with dtype torch.float32. For 2D textures, must have shape [minibatch_size, tex_height, tex_width, tex_channels]. For cube map textures, must have shape [minibatch_size, 6, tex_height, tex_width, tex_channels] where tex_width and tex_height are equal. Note that boundary_mode must also be set to 'cube' to enable cube map mode. Broadcasting is supported along the minibatch axis. + uv Tensor containing per-pixel texture coordinates. When sampling a 2D texture, must have shape [minibatch_size, height, width, 2]. When sampling a cube map texture, must have shape [minibatch_size, height, width, 3]. + uv_da (Optional) Tensor containing image-space derivatives of texture coordinates. Must have same shape as uv except for the last dimension that is to be twice as long. + mip_level_bias (Optional) Per-pixel bias for mip level selection. If uv_da is omitted, determines mip level directly. Must have shape [minibatch_size, height, width]. + mip (Optional) Preconstructed mipmap stack from a texture_construct_mip() call, or a list of tensors specifying a custom mipmap stack. When specifying a custom mipmap stack, the tensors in the list must follow the same format as tex except for width and height that must follow the usual rules for mipmap sizes. The base level texture is still supplied in tex and must not be included in the list. Gradients of a custom mipmap stack are not automatically propagated to base texture but the mipmap tensors will receive gradients of their own. If a mipmap stack is not specified but the chosen filter mode requires it, the mipmap stack is constructed internally and discarded afterwards. + filter_mode Texture filtering mode to be used. Valid values are 'auto', 'nearest', 'linear', 'linear-mipmap-nearest', and 'linear-mipmap-linear'. Mode 'auto' selects 'linear' if neither uv_da or mip_level_bias is specified, and 'linear-mipmap-linear' when at least one of them is specified, these being the highest-quality modes possible depending on the availability of the image-space derivatives of the texture coordinates or direct mip level information. + boundary_mode Valid values are 'wrap', 'clamp', 'zero', and 'cube'. If tex defines a cube map, this must be set to 'cube'. The default mode 'wrap' takes fractional part of texture coordinates. Mode 'clamp' clamps texture coordinates to the centers of the boundary texels. Mode 'zero' virtually extends the texture with all-zero values in all directions. + max_mip_level If specified, limits the number of mipmaps constructed and used in mipmap-based filter modes. + Returns: + A tensor containing the results of the texture sampling with shape [minibatch_size, height, width, tex_channels]. Cube map fetches with invalid uv coordinates (e.g., zero vectors) output all zeros and do not propagate gradients. + """ + return dr.texture( + tex.float(), + uv.float(), + uv_da, + mip_level_bias, + mip, + filter_mode, + boundary_mode, + max_mip_level, + ) + + def antialias( + self, color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0 + ): + """ + Perform antialiasing. + + All input tensors must be contiguous and reside in GPU memory. The output tensor will be contiguous and reside in GPU memory. + + Note that silhouette edge determination is based on vertex indices in the triangle tensor. For it to work properly, a vertex belonging to multiple triangles must be referred to using the same vertex index in each triangle. Otherwise, nvdiffrast will always classify the adjacent edges as silhouette edges, which leads to bad performance and potentially incorrect gradients. If you are unsure whether your data is good, check which pixels are modified by the antialias operation and compare to the example in the documentation. + + Arguments: + color Input image to antialias with shape [minibatch_size, height, width, num_channels]. + rast Main output tensor from rasterize(). + pos Vertex position tensor used in the rasterization operation. + tri Triangle tensor used in the rasterization operation. + topology_hash (Optional) Preconstructed topology hash for the triangle tensor. If not specified, the topology hash is constructed internally and discarded afterwards. + pos_gradient_boost (Optional) Multiplier for gradients propagated to pos. + Returns: + A tensor containing the antialiased image with the same shape as color input tensor. + """ + return dr.antialias( + color.float(), + rast, + pos.float(), + tri.int(), + topology_hash, + pos_gradient_boost, + ) + + +def get_clip_space_position(pos: torch.FloatTensor, mvp_mtx: torch.FloatTensor): + pos_homo = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos)], dim=-1) + return torch.matmul(pos_homo, mvp_mtx.permute(0, 2, 1)) + + +def transform_points_homo(pos: torch.FloatTensor, mtx: torch.FloatTensor): + batch_size = pos.shape[0] + pos_shape = pos.shape[1:-1] + pos = pos.reshape(batch_size, -1, 3) + pos_homo = torch.cat([pos, torch.ones_like(pos[..., 0:1])], dim=-1) + pos = (pos_homo.unsqueeze(2) * mtx.unsqueeze(1)).sum(-1)[..., :3] + pos = pos.reshape(batch_size, *pos_shape, 3) + return pos + + +class DepthNormalizationStrategy(ABC): + @abstractmethod + def __init__(self, *args, **kwargs): + pass + + @abstractmethod + def __call__( + self, depth: torch.FloatTensor, mask: torch.BoolTensor + ) -> torch.FloatTensor: + pass + + +class DepthControlNetNormalization(DepthNormalizationStrategy): + def __init__( + self, far_clip: float = 0.25, near_clip: float = 1.0, bg_value: float = 0.0 + ): + self.far_clip = far_clip + self.near_clip = near_clip + self.bg_value = bg_value + + def __call__( + self, depth: torch.FloatTensor, mask: torch.BoolTensor + ) -> torch.FloatTensor: + batch_size = depth.shape[0] + min_depth = depth.view(batch_size, -1).min(dim=-1)[0][:, None, None] + max_depth = depth.view(batch_size, -1).max(dim=-1)[0][:, None, None] + depth = 1.0 - ((depth - min_depth) / (max_depth - min_depth + 1e-5)).clamp( + 0.0, 1.0 + ) + depth = depth * (self.near_clip - self.far_clip) + self.far_clip + depth[~mask] = self.bg_value + return depth + + +class Zero123PlusPlusNormalization(DepthNormalizationStrategy): + def __init__(self, bg_value: float = 0.8): + self.bg_value = bg_value + + def __call__(self, depth: FloatTensor, mask: BoolTensor) -> FloatTensor: + batch_size = depth.shape[0] + min_depth = depth.view(batch_size, -1).min(dim=-1)[0][:, None, None] + max_depth = depth.view(batch_size, -1).max(dim=-1)[0][:, None, None] + depth = ((depth - min_depth) / (max_depth - min_depth + 1e-5)).clamp(0.0, 1.0) + depth[~mask] = self.bg_value + return depth + + +class SimpleNormalization(DepthNormalizationStrategy): + def __init__( + self, + scale: float = 1.0, + offset: float = -1.0, + clamp: bool = True, + bg_value: float = 1.0, + ): + self.scale = scale + self.offset = offset + self.clamp = clamp + self.bg_value = bg_value + + def __call__(self, depth: FloatTensor, mask: BoolTensor) -> FloatTensor: + depth = depth * self.scale + self.offset + if self.clamp: + depth = depth.clamp(0.0, 1.0) + depth[~mask] = self.bg_value + return depth + + +def render( + ctx: NVDiffRastContextWrapper, + mesh: TexturedMesh, + cam: Camera, + height: int, + width: int, + render_attr: bool = True, + render_depth: bool = True, + render_normal: bool = True, + depth_normalization_strategy: DepthNormalizationStrategy = DepthControlNetNormalization(), + attr_background: Union[float, torch.FloatTensor] = 0.5, + antialias_attr=False, + normal_background: Union[float, torch.FloatTensor] = 0.5, + texture_override=None, + texture_filter_mode: str = "linear", +) -> RenderOutput: + output_dict = {} + + v_pos_clip = get_clip_space_position(mesh.v_pos, cam.mvp_mtx) + rast, _ = ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width), grad_db=True) + mask = rast[..., 3] > 0 + + gb_pos, _ = ctx.interpolate(mesh.v_pos[None], rast, mesh.t_pos_idx) + output_dict.update({"mask": mask, "pos": gb_pos}) + + if render_depth: + gb_pos_vs = transform_points_homo(gb_pos, cam.w2c) + gb_depth = -gb_pos_vs[..., 2] + # set background pixels to min depth value for correct min/max calculation + gb_depth = torch.where( + mask, + gb_depth, + gb_depth.view(gb_depth.shape[0], -1).min(dim=-1)[0][:, None, None], + ) + gb_depth = depth_normalization_strategy(gb_depth, mask) + output_dict["depth"] = gb_depth + + if render_attr: + tex_c, _ = ctx.interpolate(mesh.v_tex[None], rast, mesh.t_tex_idx) + texture = ( + texture_override[None] + if texture_override is not None + else mesh.texture[None] + ) + gb_rgb_fg = ctx.texture(texture, tex_c, filter_mode=texture_filter_mode) + gb_rgb_bg = torch.ones_like(gb_rgb_fg) * attr_background + gb_rgb = torch.where(mask[..., None], gb_rgb_fg, gb_rgb_bg) + if antialias_attr: + gb_rgb = ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx) + output_dict["attr"] = gb_rgb + + if render_normal: + gb_nrm, _ = ctx.interpolate(mesh.v_nrm[None], rast, mesh.stitched_t_pos_idx) + gb_nrm = F.normalize(gb_nrm, dim=-1, p=2) + gb_nrm[~mask] = normal_background + output_dict["normal"] = gb_nrm + + return RenderOutput(**output_dict) diff --git a/step1x3d_texture/utils/saving.py b/step1x3d_texture/utils/saving.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ce1ae0baeeab5e714c3e5a038cc2dfaf0aa626 --- /dev/null +++ b/step1x3d_texture/utils/saving.py @@ -0,0 +1,533 @@ +import json +import math +import os +import re +import shutil +from typing import List, Optional, Union + +import cv2 +import imageio +import matplotlib.pyplot as plt +import numpy as np +import torch + +# import wandb +from matplotlib import cm +from matplotlib.colors import LinearSegmentedColormap +from PIL import Image, ImageDraw + +from .typing import * + + +def tensor_to_image( + data: Union[Image.Image, torch.Tensor, np.ndarray], + batched: bool = False, + format: str = "HWC", +) -> Union[Image.Image, List[Image.Image]]: + if isinstance(data, Image.Image): + return data + if isinstance(data, torch.Tensor): + data = data.detach().cpu().numpy() + if data.dtype == np.float32 or data.dtype == np.float16: + data = (data * 255).astype(np.uint8) + elif data.dtype == np.bool_: + data = data.astype(np.uint8) * 255 + assert data.dtype == np.uint8 + if format == "CHW": + if batched and data.ndim == 4: + data = data.transpose((0, 2, 3, 1)) + elif not batched and data.ndim == 3: + data = data.transpose((1, 2, 0)) + + if batched: + return [Image.fromarray(d) for d in data] + return Image.fromarray(data) + + +def largest_factor_near_sqrt(n: int) -> int: + """ + Finds the largest factor of n that is closest to the square root of n. + + Args: + n (int): The integer for which to find the largest factor near its square root. + + Returns: + int: The largest factor of n that is closest to the square root of n. + """ + sqrt_n = int(math.sqrt(n)) # Get the integer part of the square root + + # First, check if the square root itself is a factor + if sqrt_n * sqrt_n == n: + return sqrt_n + + # Otherwise, find the largest factor by iterating from sqrt_n downwards + for i in range(sqrt_n, 0, -1): + if n % i == 0: + return i + + # If n is 1, return 1 + return 1 + + +def make_image_grid( + images: List[Image.Image], + rows: Optional[int] = None, + cols: Optional[int] = None, + resize: Optional[int] = None, +) -> Image.Image: + """ + Prepares a single grid of images. Useful for visualization purposes. + """ + if rows is None and cols is not None: + assert len(images) % cols == 0 + rows = len(images) // cols + elif cols is None and rows is not None: + assert len(images) % rows == 0 + cols = len(images) // rows + elif rows is None and cols is None: + rows = largest_factor_near_sqrt(len(images)) + cols = len(images) // rows + + assert len(images) == rows * cols + + if resize is not None: + images = [img.resize((resize, resize)) for img in images] + + w, h = images[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(images): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +class SaverMixin: + _save_dir: Optional[str] = None + _wandb_logger: Optional[Any] = None + + def set_save_dir(self, save_dir: str): + self._save_dir = save_dir + + def get_save_dir(self): + if self._save_dir is None: + raise ValueError("Save dir is not set") + return self._save_dir + + def convert_data(self, data): + if data is None: + return None + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + if data.dtype in [torch.float16, torch.bfloat16]: + data = data.float() + return data.detach().cpu().numpy() + elif isinstance(data, list): + return [self.convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: self.convert_data(v) for k, v in data.items()} + else: + raise TypeError( + "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", + type(data), + ) + + def get_save_path(self, filename): + save_path = os.path.join(self.get_save_dir(), filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + return save_path + + DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)} + DEFAULT_UV_KWARGS = { + "data_format": "HWC", + "data_range": (0, 1), + "cmap": "checkerboard", + } + DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"} + DEFAULT_GRID_KWARGS = {"align": "max"} + + def get_rgb_image_(self, img, data_format, data_range, rgba=False): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + if img.dtype != np.uint8: + img = img.clip(min=data_range[0], max=data_range[1]) + img = ( + (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0 + ).astype(np.uint8) + nc = 4 if rgba else 3 + imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)] + imgs = [ + ( + img_ + if img_.shape[-1] == nc + else np.concatenate( + [ + img_, + np.zeros( + (img_.shape[0], img_.shape[1], nc - img_.shape[2]), + dtype=img_.dtype, + ), + ], + axis=-1, + ) + ) + for img_ in imgs + ] + img = np.concatenate(imgs, axis=1) + if rgba: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_rgb_image( + self, + filename, + img, + data_format, + data_range, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_rgb_image_(img, data_format, data_range) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + self._wandb_logger.log_image( + key=name, images=[self.get_save_path(filename)], step=step + ) + + def save_rgb_image( + self, + filename, + img, + data_format=DEFAULT_RGB_KWARGS["data_format"], + data_range=DEFAULT_RGB_KWARGS["data_range"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_rgb_image(save_path, img, data_format, data_range, name, step) + return save_path + + def get_uv_image_(self, img, data_format, data_range, cmap): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], max=data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in ["checkerboard", "color"] + if cmap == "checkerboard": + n_grid = 64 + mask = (img * n_grid).astype(int) + mask = (mask[..., 0] + mask[..., 1]) % 2 == 0 + img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 + img[mask] = np.array([255, 0, 255], dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif cmap == "color": + img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) + img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) + img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) + img = img_ + return img + + def save_uv_image( + self, + filename, + img, + data_format=DEFAULT_UV_KWARGS["data_format"], + data_range=DEFAULT_UV_KWARGS["data_range"], + cmap=DEFAULT_UV_KWARGS["cmap"], + ) -> str: + save_path = self.get_save_path(filename) + img = self.get_uv_image_(img, data_format, data_range, cmap) + cv2.imwrite(save_path, img) + return save_path + + def get_grayscale_image_(self, img, data_range, cmap): + img = self.convert_data(img) + img = np.nan_to_num(img) + if data_range is None: + img = (img - img.min()) / (img.max() - img.min()) + else: + img = img.clip(data_range[0], data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in [None, "jet", "magma", "spectral"] + if cmap == None: + img = (img * 255.0).astype(np.uint8) + img = np.repeat(img[..., None], 3, axis=2) + elif cmap == "jet": + img = (img * 255.0).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + elif cmap == "magma": + img = 1.0 - img + base = cm.get_cmap("magma") + num_bins = 256 + colormap = LinearSegmentedColormap.from_list( + f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins + )(np.linspace(0, 1, num_bins))[:, :3] + a = np.floor(img * 255.0) + b = (a + 1).clip(max=255.0) + f = img * 255.0 - a + a = a.astype(np.uint16).clip(0, 255) + b = b.astype(np.uint16).clip(0, 255) + img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None] + img = (img * 255.0).astype(np.uint8) + elif cmap == "spectral": + colormap = plt.get_cmap("Spectral") + + def blend_rgba(image): + image = image[..., :3] * image[..., -1:] + ( + 1.0 - image[..., -1:] + ) # blend A to RGB + return image + + img = colormap(img) + img = blend_rgba(img) + img = (img * 255).astype(np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_grayscale_image( + self, + filename, + img, + data_range, + cmap, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_grayscale_image_(img, data_range, cmap) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + self._wandb_logger.log_image( + key=name, images=[self.get_save_path(filename)], step=step + ) + + def save_grayscale_image( + self, + filename, + img, + data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"], + cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_grayscale_image(save_path, img, data_range, cmap, name, step) + return save_path + + def get_image_grid_(self, imgs, align): + if isinstance(imgs[0], list): + return np.concatenate( + [self.get_image_grid_(row, align) for row in imgs], axis=0 + ) + cols = [] + for col in imgs: + assert col["type"] in ["rgb", "uv", "grayscale"] + if col["type"] == "rgb": + rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() + rgb_kwargs.update(col["kwargs"]) + cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs)) + elif col["type"] == "uv": + uv_kwargs = self.DEFAULT_UV_KWARGS.copy() + uv_kwargs.update(col["kwargs"]) + cols.append(self.get_uv_image_(col["img"], **uv_kwargs)) + elif col["type"] == "grayscale": + grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() + grayscale_kwargs.update(col["kwargs"]) + cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs)) + + if align == "max": + h = max([col.shape[0] for col in cols]) + elif align == "min": + h = min([col.shape[0] for col in cols]) + elif isinstance(align, int): + h = align + else: + raise ValueError( + f"Unsupported image grid align: {align}, should be min, max, or int" + ) + + for i in range(len(cols)): + if cols[i].shape[0] != h: + w = int(cols[i].shape[1] * h / cols[i].shape[0]) + cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_CUBIC) + return np.concatenate(cols, axis=1) + + def save_image_grid( + self, + filename, + imgs, + align=DEFAULT_GRID_KWARGS["align"], + name: Optional[str] = None, + step: Optional[int] = None, + texts: Optional[List[float]] = None, + ): + save_path = self.get_save_path(filename) + img = self.get_image_grid_(imgs, align=align) + + if texts is not None: + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + black, white = (0, 0, 0), (255, 255, 255) + for i, text in enumerate(texts): + draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black) + img = np.asarray(img) + + cv2.imwrite(save_path, img) + if name and self._wandb_logger: + self._wandb_logger.log_image(key=name, images=[save_path], step=step) + return save_path + + def save_image(self, filename, img) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.dtype == np.uint8 or img.dtype == np.uint16 + if img.ndim == 3 and img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif img.ndim == 3 and img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + cv2.imwrite(save_path, img) + return save_path + + def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] + + imgs_full = [] + for start in range(0, img.shape[-1], 3): + img_ = img[..., start : start + 3] + img_ = np.stack( + [ + self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba) + for i in range(img_.shape[0]) + ], + axis=0, + ) + size = img_.shape[1] + placeholder = np.zeros((size, size, 3), dtype=np.float32) + img_full = np.concatenate( + [ + np.concatenate( + [placeholder, img_[2], placeholder, placeholder], axis=1 + ), + np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), + np.concatenate( + [placeholder, img_[3], placeholder, placeholder], axis=1 + ), + ], + axis=0, + ) + imgs_full.append(img_full) + + imgs_full = np.concatenate(imgs_full, axis=1) + cv2.imwrite(save_path, imgs_full) + return save_path + + def save_data(self, filename, data) -> str: + data = self.convert_data(data) + if isinstance(data, dict): + if not filename.endswith(".npz"): + filename += ".npz" + save_path = self.get_save_path(filename) + np.savez(save_path, **data) + else: + if not filename.endswith(".npy"): + filename += ".npy" + save_path = self.get_save_path(filename) + np.save(save_path, data) + return save_path + + def save_state_dict(self, filename, data) -> str: + save_path = self.get_save_path(filename) + torch.save(data, save_path) + return save_path + + def save_img_sequence( + self, + filename, + img_dir, + matcher, + save_format="mp4", + fps=30, + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + assert save_format in ["gif", "mp4"] + if not filename.endswith(save_format): + filename += f".{save_format}" + save_path = self.get_save_path(filename) + matcher = re.compile(matcher) + img_dir = os.path.join(self.get_save_dir(), img_dir) + imgs = [] + for f in os.listdir(img_dir): + if matcher.search(f): + imgs.append(f) + imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) + imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] + + if save_format == "gif": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps, palettesize=256) + elif save_format == "mp4": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps) + if name and self._wandb_logger: + from .core import warn + + warn("Wandb logger does not support video logging yet!") + return save_path + + def save_img_sequences( + self, + seq_dir, + matcher, + save_format="mp4", + fps=30, + delete=True, + name: Optional[str] = None, + step: Optional[int] = None, + ): + seq_dir_ = os.path.join(self.get_save_dir(), seq_dir) + for f in os.listdir(seq_dir_): + img_dir_ = os.path.join(seq_dir_, f) + if not os.path.isdir(img_dir_): + continue + try: + self.save_img_sequence( + os.path.join(seq_dir, f), + os.path.join(seq_dir, f), + matcher, + save_format=save_format, + fps=fps, + name=f"{name}_{f}", + step=step, + ) + if delete: + shutil.rmtree(img_dir_) + except: + from .core import warn + + warn(f"Video saving for directory {seq_dir_} failed!") + + def save_file(self, filename, src_path, delete=False) -> str: + save_path = self.get_save_path(filename) + shutil.copyfile(src_path, save_path) + if delete: + os.remove(src_path) + return save_path + + def save_json(self, filename, payload) -> str: + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(json.dumps(payload)) + return save_path diff --git a/step1x3d_texture/utils/shape_post_process.py b/step1x3d_texture/utils/shape_post_process.py new file mode 100644 index 0000000000000000000000000000000000000000..135573e535f724369bfd0c07abe01db780c3d682 --- /dev/null +++ b/step1x3d_texture/utils/shape_post_process.py @@ -0,0 +1,145 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import os +import tempfile +from typing import Union + +import numpy as np +import pymeshlab +import torch +import trimesh + + +def load_mesh(path): + if path.endswith(".glb"): + mesh = trimesh.load(path) + else: + mesh = pymeshlab.MeshSet() + mesh.load_new_mesh(path) + return mesh + + +def reduce_face(mesh: pymeshlab.MeshSet, max_facenum: int = 200000): + if max_facenum > mesh.current_mesh().face_number(): + return mesh + + mesh.apply_filter( + "meshing_decimation_quadric_edge_collapse", + targetfacenum=max_facenum, + qualitythr=1.0, + preserveboundary=True, + boundaryweight=3, + preservenormal=True, + preservetopology=True, + autoclean=True, + ) + return mesh + + +def remove_floater(mesh: pymeshlab.MeshSet): + mesh.apply_filter( + "compute_selection_by_small_disconnected_components_per_face", nbfaceratio=0.005 + ) + mesh.apply_filter("compute_selection_transfer_face_to_vertex", inclusive=False) + mesh.apply_filter("meshing_remove_selected_vertices_and_faces") + return mesh + + +def pymeshlab2trimesh(mesh: pymeshlab.MeshSet): + with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file: + mesh.save_current_mesh(temp_file.name) + mesh = trimesh.load(temp_file.name) + # 检查加载的对象类型 + if isinstance(mesh, trimesh.Scene): + combined_mesh = trimesh.Trimesh() + # 如果是Scene,遍历所有的geometry并合并 + for geom in mesh.geometry.values(): + combined_mesh = trimesh.util.concatenate([combined_mesh, geom]) + mesh = combined_mesh + return mesh + + +def trimesh2pymeshlab(mesh: trimesh.Trimesh): + with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file: + if isinstance(mesh, trimesh.scene.Scene): + for idx, obj in enumerate(mesh.geometry.values()): + if idx == 0: + temp_mesh = obj + else: + temp_mesh = temp_mesh + obj + mesh = temp_mesh + mesh.export(temp_file.name) + mesh = pymeshlab.MeshSet() + mesh.load_new_mesh(temp_file.name) + return mesh + + +def import_mesh( + mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, str], +) -> pymeshlab.MeshSet: + if isinstance(mesh, str): + mesh = load_mesh(mesh) + + if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)): + mesh = trimesh2pymeshlab(mesh) + + return mesh + + +def export_mesh(input, output): + if isinstance(input, pymeshlab.MeshSet): + mesh = output + else: + mesh = pymeshlab2trimesh(output) + return mesh + + +class FaceReducer: + def __call__( + self, + mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, str], + max_facenum: int = 40000, + ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]: + ms = import_mesh(mesh) + ms = reduce_face(ms, max_facenum=max_facenum) + mesh = export_mesh(mesh, ms) + return mesh + + +class FloaterRemover: + def __call__( + self, + mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, str], + ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]: + ms = import_mesh(mesh) + ms = remove_floater(ms) + mesh = export_mesh(mesh, ms) + return mesh + + +class DegenerateFaceRemover: + def __call__( + self, + mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, str], + ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]: + ms = import_mesh(mesh) + + with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file: + ms.save_current_mesh(temp_file.name) + ms = pymeshlab.MeshSet() + ms.load_new_mesh(temp_file.name) + + mesh = export_mesh(mesh, ms) + return mesh diff --git a/step1x3d_texture/utils/typing.py b/step1x3d_texture/utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..d4bb88ebc931bf49462c973d89980c671ce64915 --- /dev/null +++ b/step1x3d_texture/utils/typing.py @@ -0,0 +1,40 @@ +""" +This module contains type annotations for the project, using +1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects +2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors + +Two types of typing checking can be used: +1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) +2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) +""" + +# Basic types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + NamedTuple, + NewType, + Optional, + Sized, + Tuple, + Type, + TypeVar, + Union, +) + +# Tensor dtype +# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md +from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt + +# Config type +from omegaconf import DictConfig, ListConfig + +# PyTorch Tensor type +from torch import Tensor + +# Runtime type checking decorator +from typeguard import typechecked as typechecker