File size: 1,992 Bytes
8d34f50 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import numpy as np
import os
from util import *
import argparse
def set_requires_grad(tensor_list):
    """Enable gradient tracking on every tensor in ``tensor_list``.

    Args:
        tensor_list: Iterable of tensors. Each element has its
            ``requires_grad`` attribute set to ``True`` in place.

    Returns:
        None. The tensors are modified in place.
    """
    for tensor in tensor_list:
        tensor.requires_grad = True
# Bundle-adjustment script.
#
# Reads `track_params.pt` and `track_xys.npy` from the directory given by
# --path, optimises a shared set of 3D points together with the per-frame
# euler angles and translations by minimising `cal_lan_loss` between the
# output of `forward_transform` and the tracked 2D points, and writes the
# result to `bundle_adjustment.pt` in the same directory.
#
# NOTE(review): `torch`, `forward_transform` and `cal_lan_loss` are not
# imported explicitly in this file; they presumably come from
# `from util import *` -- confirm.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--path", type=str, default="", help="idname of target person")
parser.add_argument('--img_h', type=int, default=512, help='height if image')
parser.add_argument('--img_w', type=int, default=512, help='width of image')
args = parser.parse_args()
id_dir = args.path

# Load the tracking results and move everything to the GPU.
params_dict = torch.load(os.path.join(id_dir, 'track_params.pt'))
euler_angle = params_dict['euler'].cuda()
# Translation is rescaled by 1/1000 (presumably a unit conversion -- confirm).
trans = params_dict['trans'].cuda() / 1000.0
focal_len = params_dict['focal'].cuda()
track_xys = torch.as_tensor(
    np.load(os.path.join(id_dir, 'track_xys.npy'))).float().cuda()
num_frames = track_xys.shape[0]
point_num = track_xys.shape[1]

# One 3D point per tracked 2D point, shared by all frames, initialised at 0.
pts = torch.zeros(
    (point_num, 3), dtype=torch.float32, device=track_xys.device)
# Image centre (x, y) derived from the command-line image size.
cxy = torch.tensor(
    [args.img_w / 2.0, args.img_h / 2.0],
    dtype=torch.float32, device=track_xys.device)

# Each stage is (parameters to optimise, Adam learning rate, iterations).
# Stage 1 fits only the 3D points; stage 2 refines points and pose jointly.
stages = (
    ([pts], 1e-2, 500),
    ([pts, euler_angle, trans], 1e-4, 8000),
)
for stage_params, stage_lr, stage_iters in stages:
    # Gradient tracking is switched on only for the tensors this stage
    # optimises, so stage 1 does not compute pose gradients it never uses.
    set_requires_grad(stage_params)
    optimizer = torch.optim.Adam(stage_params, lr=stage_lr)
    for _ in range(stage_iters):
        # Broadcast the shared points to every frame before transforming.
        proj_pts = forward_transform(
            pts.unsqueeze(0).expand(num_frames, -1, -1),
            euler_angle, trans, focal_len, cxy)
        # Only the first two output components are compared with the tracks.
        loss = cal_lan_loss(proj_pts[..., :2], track_xys)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Persist the refined pose (and the unchanged focal length) on the CPU.
torch.save({'euler': euler_angle.detach().cpu(),
            'trans': trans.detach().cpu(),
            'focal': focal_len.detach().cpu()},
           os.path.join(id_dir, 'bundle_adjustment.pt'))
print('bundle adjustment params saved')
|