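"""Convert TensorFlow.js PoseNet (MobileNetV1) checkpoints to PyTorch weights.

Reads the TFJS manifest and binary weight files, maps the variable names onto
the MobileNetV1 model defined in this repo, and saves the result as a .pth
state dict, optionally sanity-checking the converted model on a sample image.
"""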
import json
import struct
import cv2
import numpy as np
import os
import tempfile
import torch
from posenet import MobileNetV1, MOBILENET_V1_CHECKPOINTS
BASE_DIR = os.path.join(tempfile.gettempdir(), '_posenet_weights')
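

# Map a TFJS variable name such as 'MobilenetV1/Conv2d_1_depthwise/depthwise_weights'
# to the matching PyTorch state-dict key (e.g. 'features.conv1.depthwise.weight').
# Names with no counterpart in the PyTorch model map to an empty string.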
def to_torch_name(tf_name):
tf_name = tf_name.lower()
tf_split = tf_name.split('/')
tf_layer_split = tf_split[1].split('_')
tf_variable_type = tf_split[2]
    if tf_variable_type in ('weights', 'depthwise_weights'):
variable_postfix = '.weight'
elif tf_variable_type == 'biases':
variable_postfix = '.bias'
else:
variable_postfix = ''
if tf_layer_split[0] == 'conv2d':
torch_name = 'features.conv' + tf_layer_split[1]
if len(tf_layer_split) > 2:
torch_name += '.' + tf_layer_split[2]
else:
torch_name += '.conv'
torch_name += variable_postfix
else:
if tf_layer_split[0] in ['offset', 'displacement', 'heatmap'] and tf_layer_split[-1] == '2':
torch_name = '_'.join(tf_layer_split[:-1])
torch_name += variable_postfix
else:
torch_name = ''
return torch_name
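

# Download (if needed) and parse a TFJS checkpoint: every entry in manifest.json
# points to a raw float32 blob that is unpacked, reshaped, and, for 4-D conv
# kernels, transposed into PyTorch layout before being collected into a state dict.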
def load_variables(chkpoint, base_dir=BASE_DIR):
manifest_path = os.path.join(base_dir, chkpoint, "manifest.json")
if not os.path.exists(manifest_path):
print('Weights for checkpoint %s are not downloaded. Downloading to %s ...' % (chkpoint, base_dir))
from posenet.converter.wget import download
download(chkpoint, base_dir)
assert os.path.exists(manifest_path)
    with open(manifest_path) as manifest:
        variables = json.load(manifest)
state_dict = {}
for x in variables:
torch_name = to_torch_name(x)
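        # Skip TFJS variables that have no counterpart in the PyTorch model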
if not torch_name:
continue
filename = variables[x]["filename"]
        with open(os.path.join(base_dir, chkpoint, filename), 'rb') as f:
            byte = f.read()
        fmt = '%df' % (len(byte) // struct.calcsize('f'))
        d = np.array(struct.unpack(fmt, byte), dtype=np.float32)
shape = variables[x]["shape"]
if len(shape) == 4:
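            # TFJS stores conv kernels as HWIO (depthwise as HWIM); PyTorch
            # expects OIHW, so permute the axes accordingly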
tpt = (2, 3, 0, 1) if 'depthwise' in filename else (3, 2, 0, 1)
d = np.reshape(d, shape).transpose(tpt)
state_dict[torch_name] = torch.Tensor(d)
return state_dict
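

# Load an image with OpenCV and preprocess it for the network: resize,
# BGR -> RGB, scale to [-1, 1], and reorder HWC -> CHW.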
def _read_imgfile(path, width, height):
img = cv2.imread(path)
img = cv2.resize(img, (width, height))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img.astype(np.float32)
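    # Scale pixel values from [0, 255] into [-1, 1]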
img = img * (2.0 / 255.0) - 1.0
img = img.transpose((2, 0, 1))
return img
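

# Convert the TFJS checkpoint identified by model_id into a PyTorch MobileNetV1
# PoseNet, save it to <model_dir>/<checkpoint_name>.pth, and optionally run the
# converted model on a bundled test image as a quick sanity check.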
def convert(model_id, model_dir, output_stride=16, image_size=513, check=True):
checkpoint_name = MOBILENET_V1_CHECKPOINTS[model_id]
width = image_size
height = image_size
if not os.path.exists(model_dir):
os.makedirs(model_dir)
state_dict = load_variables(checkpoint_name)
m = MobileNetV1(model_id, output_stride=output_stride)
m.load_state_dict(state_dict)
checkpoint_path = os.path.join(model_dir, checkpoint_name) + '.pth'
torch.save(m.state_dict(), checkpoint_path)
if check and os.path.exists("TryYours-Virtual-Try-On/posenet/converter/images/tennis_in_crowd.jpg"):
        # Sanity check: run the converted model on the bundled test image and print the heatmap output
input_image = _read_imgfile("TryYours-Virtual-Try-On/posenet/converter/images/tennis_in_crowd.jpg", width, height)
input_image = np.array(input_image, dtype=np.float32)
input_image = input_image.reshape(1, 3, height, width)
input_image = torch.Tensor(input_image)
heatmaps_result, offset_result, displacement_fwd_result, displacement_bwd_result = m(input_image)
print("Heatmaps")
print(heatmaps_result.shape)
print(heatmaps_result[:, 0:1, 0:1])
print(torch.mean(heatmaps_result))
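

# Example usage (a minimal sketch; model id 101 and the './_models' output
# directory are assumed values, adjust them to your setup)
if __name__ == '__main__':
    convert(101, './_models', output_stride=16)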