File size: 1,992 Bytes
8d34f50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import os
from util import *
import argparse


def set_requires_grad(tensor_list):
    """Turn on gradient tracking for every element of ``tensor_list``.

    Each element gets its ``requires_grad`` attribute set to ``True`` in
    place. Nothing is returned.
    """
    for entry in tensor_list:
        entry.requires_grad = True


# Command-line interface.
# --path is used further down as the directory that holds track_params.pt and
# track_xys.npy, and into which bundle_adjustment.pt is written.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--path", type=str, default="", help="idname of target person")
parser.add_argument('--img_h', type=int, default=512, help='height of image')
parser.add_argument('--img_w', type=int, default=512, help='width of image')
args = parser.parse_args()
# With the default "" the data files are resolved relative to the current
# working directory (os.path.join("", name) == name).
id_dir = args.path

# Load the tracking results stored in id_dir and move them to the GPU.
# NOTE(review): `torch` is not imported explicitly in this file; it is
# presumably re-exported by `from util import *` -- confirm.
params_dict = torch.load(os.path.join(id_dir, 'track_params.pt'))
# detach() guarantees that the tensors below are autograd leaves even if the
# checkpoint was saved with requires_grad=True; torch.optim refuses to
# optimise non-leaf tensors.
euler_angle = params_dict['euler'].detach().cuda()
# Translation is rescaled by 1/1000 (presumably a unit conversion -- confirm).
trans = params_dict['trans'].detach().cuda() / 1000.0
# The focal length is never handed to an optimizer in this script, so it stays
# fixed.
focal_len = params_dict['focal'].detach().cuda()

# Tracked 2D observations; axis 0 is treated as frames, axis 1 as points.
track_xys = torch.as_tensor(
    np.load(os.path.join(id_dir, 'track_xys.npy'))).float().cuda()
num_frames = track_xys.shape[0]
point_num = track_xys.shape[1]

# One 3D point per tracked point, initialised at the origin and shared by all
# frames. Allocated directly on the GPU so it is a leaf tensor.
pts = torch.zeros((point_num, 3), dtype=torch.float32, device='cuda')
set_requires_grad([euler_angle, trans, pts])

# Image centre (w/2, h/2) passed to forward_transform.
cxy = torch.tensor((args.img_w / 2.0, args.img_h / 2.0),
                   dtype=torch.float32, device='cuda')

# Stage 1: optimise only the 3D points. euler_angle and trans are not given to
# this optimizer, so they keep their loaded values during this stage.
optimizer_pts = torch.optim.Adam([pts], lr=1e-2)
iter_num = 500
for _ in range(iter_num):
    # Broadcast the single point set to every frame; expand() returns a view,
    # so the gradients of all frames accumulate into `pts`.
    proj_pts = forward_transform(pts.unsqueeze(0).expand(
        num_frames, -1, -1), euler_angle, trans, focal_len, cxy)
    # Only the first two components of the transformed points are compared
    # with the tracked 2D positions.
    loss = cal_lan_loss(proj_pts[..., :2], track_xys)
    optimizer_pts.zero_grad()
    loss.backward()
    optimizer_pts.step()


# Stage 2: bundle adjustment. The 3D points, the per-frame euler angles and
# the translations are refined jointly with a smaller learning rate.
optimizer_ba = torch.optim.Adam([pts, euler_angle, trans], lr=1e-4)

iter_num = 8000
for _ in range(iter_num):
    # Broadcast the shared point set to every frame (view, no copy).
    proj_pts = forward_transform(pts.unsqueeze(0).expand(
        num_frames, -1, -1), euler_angle, trans, focal_len, cxy)
    # Only the first two components of the transformed points are compared
    # with the tracked 2D positions.
    loss = cal_lan_loss(proj_pts[..., :2], track_xys)
    # zero_grad() precedes backward(), so gradients left on euler_angle and
    # trans by earlier backward passes are cleared before the first step.
    optimizer_ba.zero_grad()
    loss.backward()
    optimizer_ba.step()

# Write the refined parameters next to the input data as CPU tensors detached
# from the autograd graph. `trans` is stored as optimised, i.e. still carrying
# the 1/1000 scaling applied when it was loaded.
ba_params = {
    'euler': euler_angle.detach().cpu(),
    'trans': trans.detach().cpu(),
    'focal': focal_len.detach().cpu(),
}
ba_path = os.path.join(id_dir, 'bundle_adjustment.pt')
torch.save(ba_params, ba_path)
print('bundle adjustment params saved')