Spaces:
mucusz
/
Runtime error

File size: 4,403 Bytes
da48dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487ee6d
 
 
da48dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb140f6
da48dbe
 
 
 
 
 
 
 
fb140f6
 
da48dbe
 
 
 
 
 
 
 
 
 
fb140f6
da48dbe
 
 
fb140f6
 
 
da48dbe
 
 
 
 
 
 
 
 
 
 
 
 
fb140f6
 
da48dbe
 
fb140f6
 
 
 
da48dbe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# -*- coding: utf-8 -*-
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# Using this computer program means that you agree to the terms
# in the LICENSE file included with this software distribution.
# Any use not explicitly granted by the LICENSE is prohibited.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# For comments or questions, please email us at [email protected]
# For commercial licensing contact, please contact [email protected]

import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class FLAMETex(nn.Module):
    """
    Linear texture model: texture = mean + basis @ texcode.

    FLAME texture:
    https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64
    FLAME texture converted from BFM:
    https://github.com/TimoBolkart/BFM_to_FLAME

    The constructor reads these attributes from ``config``:
        tex_type:       "BFM" or "FLAME"; selects the archive layout.
        tex_path:       archive path, read when tex_type == "BFM"
                        (keys "MU" / "PC", 199 components).
        flame_tex_path: archive path, read when tex_type == "FLAME"
                        (keys "mean" / "tex_dir", 200 components, values are
                        divided by 255).
        n_tex:          number of leading basis components to keep.

    Registered buffers (names and shapes are part of the state_dict):
        texture_mean:  [1, 1, n_values]
        texture_basis: [1, n_values, n_tex]

    Raises:
        NotImplementedError: if ``config.tex_type`` is neither "BFM" nor "FLAME".
    """
    def __init__(self, config):
        super(FLAMETex, self).__init__()
        if config.tex_type == "BFM":
            mu_key = "MU"
            pc_key = "PC"
            n_pc = 199
            tex_path = config.tex_path
            value_scale = None    # BFM values are used as stored
        elif config.tex_type == "FLAME":
            mu_key = "mean"
            pc_key = "tex_dir"
            n_pc = 200
            tex_path = config.flame_tex_path
            value_scale = 255.0    # FLAME values are divided by 255
        else:
            # Carry the reason inside the exception instead of printing it, so
            # it is visible wherever the error is caught or logged.
            raise NotImplementedError(
                f"texture type {config.tex_type!r} does not exist "
                "(expected 'BFM' or 'FLAME')"
            )

        n_tex = config.n_tex

        # The archive object holds an open file handle; close it as soon as
        # the two arrays have been read into memory.
        with np.load(tex_path) as tex_space:
            texture_mean = tex_space[mu_key].reshape(1, -1)
            texture_basis = tex_space[pc_key].reshape(-1, n_pc)

        # Truncate before rescaling: same values, but the temporary created by
        # the division only covers the n_tex components that are kept.
        texture_basis = texture_basis[:, :n_tex]
        if value_scale is not None:
            texture_mean = texture_mean / value_scale
            texture_basis = texture_basis / value_scale

        texture_mean = torch.from_numpy(texture_mean).float()[None, ...]
        texture_basis = torch.from_numpy(texture_basis).float()[None, ...]
        self.register_buffer("texture_mean", texture_mean)
        self.register_buffer("texture_basis", texture_basis)

    def forward(self, texcode=None):
        """
        texcode: [batchsize, n_tex]; ``None`` is treated as the all-zero code
                 for a batch of one, i.e. the mean texture is returned.
        texture: [bz, 3, 256, 256]
                 (range 0-1 per the original note; only the FLAME branch
                 rescales by 1/255 — NOTE(review): confirm BFM data is already
                 normalized)
        """
        if texcode is None:
            # Zero coefficients on the same device/dtype as the basis buffer.
            texcode = self.texture_basis.new_zeros(1, self.texture_basis.shape[-1])

        # [1, n_values, n_tex] * [bz, 1, n_tex] -> sum over components -> [bz, n_values].
        # Adding the [1, 1, n_values] mean broadcasts to [1, bz, n_values]; the
        # reshape below folds that back into a leading batch dimension.
        texture = self.texture_mean + (self.texture_basis * texcode[:, None, :]).sum(-1)
        # The flat vector is interpreted as a 512x512 image with 3 channels last,
        # then moved to channels-first.
        texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
        # Downsample to 256x256 (F.interpolate default mode).
        texture = F.interpolate(texture, [256, 256])
        # Reverse the channel order (index 2 <-> index 0).
        texture = texture[:, [2, 1, 0], :, :]
        return texture


def texture_flame2smplx(cached_data, flame_texture, smplx_texture):
    """Convert flame texture map (face-only) into smplx texture map (includes body texture)

    Pixels are copied from ``flame_texture`` into ``smplx_texture`` in place,
    using a precomputed correspondence.

    Args:
        cached_data: mapping with the keys
            "target_resolution": expected side length of ``smplx_texture``,
            "x_coords", "y_coords": arrays indexed by "target_pixel_ids" to
                obtain destination columns / rows,
            "target_pixel_ids": indices of the destination pixels to fill,
            "source_uv_points": [n, 2] array of (u, v) coordinates into
                ``flame_texture`` (presumably in [0, 1]; values outside are
                clipped to the image border).
        flame_texture: source image array, indexed as [row, col, ...].
        smplx_texture: square destination image array, modified in place.

    Returns:
        ``smplx_texture`` after the copy, or ``None`` (after printing the
        reason) when the destination is not square or does not match the
        cached resolution.

    TODO: pytorch version ==> grid sample
    """
    height, width = smplx_texture.shape[0], smplx_texture.shape[1]
    if height != width:
        # Report the dimensions, not the pixel rows (formatting array rows
        # with %d raised TypeError before the message could be shown).
        print("SMPL-X texture not squared (%d != %d)" % (height, width))
        return None

    target_resolution = cached_data["target_resolution"]
    if height != target_resolution:
        print(
            "SMPL-X texture size does not match cached image resolution (%d != %d)" %
            (height, target_resolution)
        )
        return None

    x_coords = cached_data["x_coords"]
    y_coords = cached_data["y_coords"]
    target_pixel_ids = cached_data["target_pixel_ids"]
    source_uv_points = cached_data["source_uv_points"]

    src_height, src_width = flame_texture.shape[0], flame_texture.shape[1]

    # v is flipped (1 - v) to turn the UV coordinate into an image row.
    # The upper clip bound is size - 1 so that a UV value on the border maps
    # to the last valid pixel instead of one past the end.
    source_rows = np.clip(
        src_height * (1.0 - source_uv_points[:, 1]), 0.0, src_height - 1
    ).astype(int)
    source_cols = np.clip(
        src_width * source_uv_points[:, 0], 0.0, src_width - 1
    ).astype(int)

    target_rows = y_coords[target_pixel_ids].astype(int)
    target_cols = x_coords[target_pixel_ids].astype(int)

    smplx_texture[target_rows, target_cols, :] = flame_texture[source_rows, source_cols]

    return smplx_texture