import torch
from torch import nn
from typing import List

from diffusers.models.modeling_utils import ModelMixin

import step1x3d_geometry
from step1x3d_geometry.utils.typing import *
from step1x3d_geometry.utils.misc import get_device

from .base import BaseLabelEncoder
DEFAULT_POSE = 0  # "unknown", "t-pose", "a-pose", uncond
NUM_POSE_CLASSES = 3
POSE_MAPPING = {"unknown": 0, "t-pose": 1, "a-pose": 2, "uncond": 3}

DEFAULT_SYMMETRY_TYPE = 0  # "asymmetry", "x", uncond
NUM_SYMMETRY_TYPE_CLASSES = 2
SYMMETRY_TYPE_MAPPING = {"asymmetry": 0, "x": 1, "y": 0, "z": 0, "uncond": 2}

DEFAULT_GEOMETRY_QUALITY = 0  # "normal", "smooth", "sharp", uncond
NUM_GEOMETRY_QUALITY_CLASSES = 3
GEOMETRY_QUALITY_MAPPING = {"normal": 0, "smooth": 1, "sharp": 2, "uncond": 3}
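
# Note on the mappings above: "y" and "z" share the "asymmetry" index, so only
# x-symmetry is modeled as its own class. The "uncond" label maps one past the
# named classes (e.g. POSE_MAPPING["uncond"] == 3), which is why the tables
# below are sized NUM_*_CLASSES + 1 when cfg.zero_uncond_embeds is disabled.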


class LabelEncoder(BaseLabelEncoder, ModelMixin):
    """
    Embeds class labels (pose, symmetry, geometry quality) into vector
    representations and handles label dropout for classifier-free guidance.

    Relevant config fields:
        hidden_size (`int`): The size of each label embedding vector.
        zero_uncond_embeds (`bool`): If True, the unconditional embeddings are
            all-zero vectors; otherwise each table reserves an extra learned
            "uncond" class.
        pretrained_model_name_or_path (`Optional[str]`): Optional checkpoint
            to initialize the embedding tables from.
    """
    def configure(self) -> None:
        super().configure()

        if self.cfg.zero_uncond_embeds:
            # Unconditional embeddings are all-zero vectors, so each table
            # only needs one row per named class.
            self.embedding_table_tpose = nn.Embedding(
                NUM_POSE_CLASSES, self.cfg.hidden_size
            )
            self.embedding_table_symmetry_type = nn.Embedding(
                NUM_SYMMETRY_TYPE_CLASSES, self.cfg.hidden_size
            )
            self.embedding_table_geometry_quality = nn.Embedding(
                NUM_GEOMETRY_QUALITY_CLASSES, self.cfg.hidden_size
            )
        else:
            # Reserve one extra row per table for the learned "uncond" class.
            self.embedding_table_tpose = nn.Embedding(
                NUM_POSE_CLASSES + 1, self.cfg.hidden_size
            )
            self.embedding_table_symmetry_type = nn.Embedding(
                NUM_SYMMETRY_TYPE_CLASSES + 1, self.cfg.hidden_size
            )
            self.embedding_table_geometry_quality = nn.Embedding(
                NUM_GEOMETRY_QUALITY_CLASSES + 1, self.cfg.hidden_size
            )
        if self.cfg.zero_uncond_embeds:
            self.empty_label_embeds = torch.zeros((1, 3, self.cfg.hidden_size)).detach()
        else:
            # The last class label of each table is the learned "uncond" embedding.
            self.empty_label_embeds = self.encode_label(
                [{"pose": "uncond", "symmetry": "uncond", "geometry_type": "uncond"}]
            ).detach()
        # Optionally initialize from a pretrained checkpoint.
        if self.cfg.pretrained_model_name_or_path is not None:
            print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
            ckpt = torch.load(
                self.cfg.pretrained_model_name_or_path, map_location="cpu"
            )["state_dict"]
            # Keep only the label-conditioning weights and strip their prefix.
            pretrained_model_ckpt = {}
            for k, v in ckpt.items():
                if k.startswith("label_condition."):
                    pretrained_model_ckpt[k.replace("label_condition.", "")] = v
            self.load_state_dict(pretrained_model_ckpt, strict=True)
    def encode_label(self, labels: List[dict]) -> torch.FloatTensor:
        tpose_label_embeds = []
        symmetry_type_label_embeds = []
        geometry_quality_label_embeds = []

        for label in labels:
            # Per label family: empty/None value -> zero embedding (label
            # dropout); missing key -> default class embedding.
            if "pose" in label.keys():
                if label["pose"] is None or label["pose"] == "":
                    tpose_label_embeds.append(
                        torch.zeros(self.cfg.hidden_size).detach().to(get_device())
                    )
                else:
                    tpose_label_embeds.append(
                        self.embedding_table_tpose(
                            torch.tensor(POSE_MAPPING[label["pose"]]).to(
                                get_device()
                            )
                        )
                    )
            else:
                tpose_label_embeds.append(
                    self.embedding_table_tpose(
                        torch.tensor(DEFAULT_POSE).to(get_device())
                    )
                )
if "symmetry" in label.keys(): | |
if label["symmetry"] is None or label["symmetry"] == "": | |
symmetry_type_label_embeds.append( | |
torch.zeros(self.cfg.hidden_size).detach().to(get_device()) | |
) | |
else: | |
symmetry_type_label_embeds.append( | |
self.embedding_table_symmetry_type( | |
torch.tensor( | |
SYMMETRY_TYPE_MAPPING[label["symmetry"]] | |
).to(get_device()) | |
) | |
) | |
else: | |
symmetry_type_label_embeds.append( | |
self.embedding_table_symmetry_type( | |
torch.tensor(DEFAULT_SYMMETRY_TYPE).to(get_device()) | |
) | |
) | |
if "geometry_type" in label.keys(): | |
if label["geometry_type"] is None or label["geometry_type"] == "": | |
geometry_quality_label_embeds.append( | |
torch.zeros(self.cfg.hidden_size).detach().to(get_device()) | |
) | |
else: | |
geometry_quality_label_embeds.append( | |
self.embedding_table_geometry_quality( | |
torch.tensor( | |
GEOMETRY_QUALITY_MAPPING[label["geometry_type"][0]] | |
).to(get_device()) | |
) | |
) | |
else: | |
geometry_quality_label_embeds.append( | |
self.embedding_table_geometry_quality( | |
torch.tensor(DEFAULT_GEOMETRY_QUALITY).to(get_device()) | |
) | |
) | |
        tpose_label_embeds = torch.stack(tpose_label_embeds)
        symmetry_type_label_embeds = torch.stack(symmetry_type_label_embeds)
        geometry_quality_label_embeds = torch.stack(geometry_quality_label_embeds)

        # Stack the three label families into shape (batch, 3, hidden_size):
        # [pose, symmetry, geometry_quality].
        label_embeds = torch.stack(
            [
                tpose_label_embeds,
                symmetry_type_label_embeds,
                geometry_quality_label_embeds,
            ],
            dim=1,
        ).to(self.dtype)
        return label_embeds
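

if __name__ == "__main__":
    # Minimal standalone sketch of the lookup pattern used above. It does not
    # construct a LabelEncoder (which requires the full BaseLabelEncoder cfg
    # machinery); the hidden size is illustrative only. Because this module
    # uses a relative import, run it with `python -m <package.path.to.module>`
    # rather than directly.
    hidden_size = 8
    table = nn.Embedding(NUM_POSE_CLASSES + 1, hidden_size)  # +1 row for "uncond"
    idx = torch.tensor(POSE_MAPPING["t-pose"])  # -> 1
    print(table(idx).shape)  # torch.Size([8])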