import torch
from torch import nn
from typing import List

from diffusers.models.modeling_utils import ModelMixin

import step1x3d_geometry
from step1x3d_geometry.utils.typing import *
from step1x3d_geometry.utils.misc import get_device

from .base import BaseLabelEncoder

# Label-to-index mappings. Each mapping reserves its highest index for the
# unconditional ("uncond") class used by classifier-free guidance; that index
# only gets its own embedding row when `zero_uncond_embeds` is False (see
# `configure` below). Note that "y" and "z" symmetry map to the same index as
# "asymmetry", so only x-symmetry gets a distinct embedding.
DEFAULT_POSE = 0
NUM_POSE_CLASSES = 3
POSE_MAPPING = {"unknown": 0, "t-pose": 1, "a-pose": 2, "uncond": 3}

DEFAULT_SYMMETRY_TYPE = 0
NUM_SYMMETRY_TYPE_CLASSES = 2
SYMMETRY_TYPE_MAPPING = {"asymmetry": 0, "x": 1, "y": 0, "z": 0, "uncond": 2}

DEFAULT_GEOMETRY_QUALITY = 0
NUM_GEOMETRY_QUALITY_CLASSES = 3
GEOMETRY_QUALITY_MAPPING = {"normal": 0, "smooth": 1, "sharp": 2, "uncond": 3}
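
# For example, a label dict like {"pose": ["a-pose"], "symmetry": ["x"],
# "geometry_type": ["sharp"]} maps to class indices (2, 1, 2) before being
# embedded by `encode_label` below.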


@step1x3d_geometry.register("label-encoder")
class LabelEncoder(BaseLabelEncoder, ModelMixin):
    """
    Embeds pose, symmetry, and geometry-quality labels into vector
    representations, and handles label dropout for classifier-free guidance.

    Config:
        hidden_size (`int`): The size of each label embedding vector.
        zero_uncond_embeds (`bool`): If True, the unconditional embedding is
            all zeros; otherwise each table learns a dedicated "uncond" row.
        pretrained_model_name_or_path (`str`, optional): Checkpoint from which
            to load pretrained weights.
    """

    def configure(self) -> None:
        super().configure()

        if self.cfg.zero_uncond_embeds:
            # The unconditional embedding is all zeros, so the tables only
            # need rows for the real classes.
            self.embedding_table_tpose = nn.Embedding(
                NUM_POSE_CLASSES, self.cfg.hidden_size
            )
            self.embedding_table_symmetry_type = nn.Embedding(
                NUM_SYMMETRY_TYPE_CLASSES, self.cfg.hidden_size
            )
            self.embedding_table_geometry_quality = nn.Embedding(
                NUM_GEOMETRY_QUALITY_CLASSES, self.cfg.hidden_size
            )
        else:
            # Reserve one extra row per table for the learned "uncond" class.
            self.embedding_table_tpose = nn.Embedding(
                NUM_POSE_CLASSES + 1, self.cfg.hidden_size
            )
            self.embedding_table_symmetry_type = nn.Embedding(
                NUM_SYMMETRY_TYPE_CLASSES + 1, self.cfg.hidden_size
            )
            self.embedding_table_geometry_quality = nn.Embedding(
                NUM_GEOMETRY_QUALITY_CLASSES + 1, self.cfg.hidden_size
            )

        if self.cfg.zero_uncond_embeds:
            self.empty_label_embeds = torch.zeros(
                (1, 3, self.cfg.hidden_size)
            ).detach()
        else:
            # Encode the learned "uncond" class for each label family.
            self.empty_label_embeds = self.encode_label(
                [{
                    "pose": ["uncond"],
                    "symmetry": ["uncond"],
                    "geometry_type": ["uncond"],
                }]
            ).detach()

        if self.cfg.pretrained_model_name_or_path is not None:
            print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
            ckpt = torch.load(
                self.cfg.pretrained_model_name_or_path, map_location="cpu"
            )["state_dict"]
            # Keep only the label-encoder weights, stripping their prefix.
            pretrained_model_ckpt = {}
            for k, v in ckpt.items():
                if k.startswith("label_condition."):
                    pretrained_model_ckpt[k.replace("label_condition.", "")] = v
            self.load_state_dict(pretrained_model_ckpt, strict=True)

    def encode_label(self, labels: List[dict]) -> torch.FloatTensor:
        def embed_one(label, key, table, mapping, default_index):
            # Missing key -> embedding of the default class. A None or empty
            # value marks a dropped label (classifier-free guidance) and gets
            # a zero vector. Otherwise, look up the mapped class index; label
            # values are single-element lists, hence the [0].
            if key not in label:
                return table(torch.tensor(default_index).to(get_device()))
            if label[key] is None or label[key] == "":
                return torch.zeros(self.cfg.hidden_size).detach().to(get_device())
            return table(torch.tensor(mapping[label[key][0]]).to(get_device()))

        tpose_label_embeds = []
        symmetry_type_label_embeds = []
        geometry_quality_label_embeds = []

        for label in labels:
            # Note: pose labels index the pose table; the original code
            # mistakenly routed them through the symmetry table.
            tpose_label_embeds.append(
                embed_one(
                    label,
                    "pose",
                    self.embedding_table_tpose,
                    POSE_MAPPING,
                    DEFAULT_POSE,
                )
            )
            symmetry_type_label_embeds.append(
                embed_one(
                    label,
                    "symmetry",
                    self.embedding_table_symmetry_type,
                    SYMMETRY_TYPE_MAPPING,
                    DEFAULT_SYMMETRY_TYPE,
                )
            )
            geometry_quality_label_embeds.append(
                embed_one(
                    label,
                    "geometry_type",
                    self.embedding_table_geometry_quality,
                    GEOMETRY_QUALITY_MAPPING,
                    DEFAULT_GEOMETRY_QUALITY,
                )
            )

        tpose_label_embeds = torch.stack(tpose_label_embeds)
        symmetry_type_label_embeds = torch.stack(symmetry_type_label_embeds)
        geometry_quality_label_embeds = torch.stack(geometry_quality_label_embeds)

        # One token per label family: (batch, 3, hidden_size).
        label_embeds = torch.stack(
            [
                tpose_label_embeds,
                symmetry_type_label_embeds,
                geometry_quality_label_embeds,
            ],
            dim=1,
        ).to(self.dtype)

        return label_embeds
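

# Minimal usage sketch (illustrative only; the exact cfg schema and
# construction come from BaseLabelEncoder, so anything beyond the fields used
# above is an assumption):
#
#   cfg = {"hidden_size": 1024, "zero_uncond_embeds": True,
#          "pretrained_model_name_or_path": None}
#   encoder = LabelEncoder(cfg)
#   embeds = encoder.encode_label(
#       [{"pose": ["t-pose"], "symmetry": ["x"], "geometry_type": ["sharp"]}]
#   )
#   # embeds.shape == (1, 3, hidden_size); for classifier-free guidance,
#   # encoder.empty_label_embeds provides the unconditional counterpart.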