Spaces:
Runtime error
Runtime error
import torchvision.transforms.functional as F | |
import torch | |
import pickle | |
from tops import download_file, assert_shape | |
from typing import Dict | |
from functools import lru_cache | |
global symmetry_transform | |
def get_symmetry_transform(symmetry_url): | |
file_name = download_file(symmetry_url) | |
with open(file_name, "rb") as fp: | |
symmetry = pickle.load(fp) | |
return torch.from_numpy(symmetry["vertex_transforms"]).long() | |
hflip_handled_cases = set([ | |
"keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition", | |
"embedding", "vertx2cat", "maskrcnn_mask", "__key__"]) | |
def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]: | |
container["img"] = F.hflip(container["img"]) | |
if "condition" in container: | |
container["condition"] = F.hflip(container["condition"]) | |
if "embedding" in container: | |
container["embedding"] = F.hflip(container["embedding"]) | |
assert all([key in hflip_handled_cases for key in container]), container.keys() | |
if "keypoints" in container: | |
assert flip_map is not None | |
if container["keypoints"].ndim == 3: | |
keypoints = container["keypoints"][:, flip_map, :] | |
keypoints[:, :, 0] = 1 - keypoints[:, :, 0] | |
else: | |
assert_shape(container["keypoints"], (None, 3)) | |
keypoints = container["keypoints"][flip_map, :] | |
keypoints[:, 0] = 1 - keypoints[:, 0] | |
container["keypoints"] = keypoints | |
if "mask" in container: | |
container["mask"] = F.hflip(container["mask"]) | |
if "border" in container: | |
container["border"] = F.hflip(container["border"]) | |
if "semantic_mask" in container: | |
container["semantic_mask"] = F.hflip(container["semantic_mask"]) | |
if "vertices" in container: | |
symmetry_transform = get_symmetry_transform( | |
"https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl") | |
container["vertices"] = F.hflip(container["vertices"]) | |
symmetry_transform_ = symmetry_transform.to(container["vertices"].device) | |
container["vertices"] = symmetry_transform_[container["vertices"].long()] | |
if "E_mask" in container: | |
container["E_mask"] = F.hflip(container["E_mask"]) | |
if "maskrcnn_mask" in container: | |
container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"]) | |
return container | |