File size: 3,803 Bytes
8a6df40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61990d5
8a6df40
61990d5
8a6df40
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json
import struct
import cv2
import numpy as np
import os
import tempfile
import torch

from posenet import MobileNetV1, MOBILENET_V1_CHECKPOINTS


# Default cache directory for downloaded TF checkpoint weights (under the system temp dir).
BASE_DIR = os.path.join(tempfile.gettempdir(), '_posenet_weights')


def to_torch_name(tf_name):
    """Map a TensorFlow checkpoint variable name to its PyTorch state-dict key.

    Expected input shape is 'Prefix/Layer_tokens/variable_type', e.g.
    'MobilenetV1/Conv2d_1_depthwise/depthwise_weights'. Returns '' for
    variables that have no counterpart in the torch model.
    """
    parts = tf_name.lower().split('/')
    layer_tokens = parts[1].split('_')
    var_type = parts[2]

    # Parameter suffix determined by the TF variable kind.
    if var_type in ('weights', 'depthwise_weights'):
        suffix = '.weight'
    elif var_type == 'biases':
        suffix = '.bias'
    else:
        suffix = ''

    if layer_tokens[0] == 'conv2d':
        # Backbone convs, e.g. conv2d_1_depthwise -> features.conv1.depthwise
        middle = layer_tokens[2] if len(layer_tokens) > 2 else 'conv'
        return 'features.conv' + layer_tokens[1] + '.' + middle + suffix

    # Output heads such as offset_2 / heatmap_2 / displacement_fwd_2:
    # drop the trailing '2' and join the rest with underscores.
    if layer_tokens[0] in ('offset', 'displacement', 'heatmap') and layer_tokens[-1] == '2':
        return '_'.join(layer_tokens[:-1]) + suffix

    return ''


def load_variables(chkpoint, base_dir=BASE_DIR):
    """Read a downloaded TF checkpoint and build a torch state dict.

    Downloads the checkpoint into *base_dir* first if its manifest is
    missing. Variables whose name has no torch counterpart (see
    to_torch_name) are skipped.

    :param chkpoint: checkpoint directory name under base_dir
    :param base_dir: root directory holding checkpoint folders
    :return: dict mapping torch parameter names to torch.Tensor values
    """
    manifest_path = os.path.join(base_dir, chkpoint, "manifest.json")
    if not os.path.exists(manifest_path):
        print('Weights for checkpoint %s are not downloaded. Downloading to %s ...' % (chkpoint, base_dir))
        from posenet.converter.wget import download
        download(chkpoint, base_dir)
        assert os.path.exists(manifest_path)

    # Context manager replaces the manual open/close pair (no leak on error).
    with open(manifest_path) as manifest_file:
        variables = json.load(manifest_file)

    state_dict = {}
    for x in variables:
        torch_name = to_torch_name(x)
        if not torch_name:
            continue
        filename = variables[x]["filename"]
        # 'with' ensures the weight file handle is closed (previously leaked).
        with open(os.path.join(base_dir, chkpoint, filename), 'rb') as weight_file:
            raw = weight_file.read()
        fmt = str(int(len(raw) / struct.calcsize('f'))) + 'f'
        d = np.array(struct.unpack(fmt, raw), dtype=np.float32)
        shape = variables[x]["shape"]
        if len(shape) == 4:
            # 4-D conv weights need axis reordering for torch; depthwise
            # filters use a different source layout than regular convs.
            tpt = (2, 3, 0, 1) if 'depthwise' in filename else (3, 2, 0, 1)
            d = np.reshape(d, shape).transpose(tpt)
        state_dict[torch_name] = torch.Tensor(d)

    return state_dict


def _read_imgfile(path, width, height):
    """Load an image file and return it as a normalized CHW float32 array.

    The image is resized to (width, height), converted BGR->RGB, scaled
    from [0, 255] into [-1, 1], and transposed from HWC to CHW layout.
    """
    bgr = cv2.imread(path)
    resized = cv2.resize(bgr, (width, height))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    # Map pixel range 0..255 onto -1..1 for the network input.
    normalized = rgb.astype(np.float32) * (2.0 / 255.0) - 1.0
    return normalized.transpose((2, 0, 1))


def convert(model_id, model_dir, output_stride=16, image_size=513, check=True):
    """Convert a TF MobileNetV1 posenet checkpoint into a torch .pth file.

    :param model_id: key into MOBILENET_V1_CHECKPOINTS selecting the checkpoint
    :param model_dir: output directory for the converted .pth file
    :param output_stride: stride passed through to the MobileNetV1 model
    :param image_size: width and height used for the optional sanity check
    :param check: when True and the sample image exists, run the converted
        model once and print summary statistics of its outputs
    """
    checkpoint_name = MOBILENET_V1_CHECKPOINTS[model_id]
    width = image_size
    height = image_size

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Load TF weights (downloading them if needed) into the torch model,
    # then persist the torch state dict next to the requested model_dir.
    state_dict = load_variables(checkpoint_name)
    m = MobileNetV1(model_id, output_stride=output_stride)
    m.load_state_dict(state_dict)
    checkpoint_path = os.path.join(model_dir, checkpoint_name) + '.pth'
    torch.save(m.state_dict(), checkpoint_path)

    if check and os.path.exists("TryYours-Virtual-Try-On/posenet/converter/images/tennis_in_crowd.jpg"):
        # Sanity check: run the converted model on a known sample image
        # and print output shapes/statistics for manual inspection.
        input_image = _read_imgfile("TryYours-Virtual-Try-On/posenet/converter/images/tennis_in_crowd.jpg", width, height)
        input_image = np.array(input_image, dtype=np.float32)
        input_image = input_image.reshape(1, 3, height, width)
        input_image = torch.Tensor(input_image)

        heatmaps_result, offset_result, displacement_fwd_result, displacement_bwd_result = m(input_image)

        print("Heatmaps")
        print(heatmaps_result.shape)
        print(heatmaps_result[:, 0:1, 0:1])
        print(torch.mean(heatmaps_result))