#include #include #include #include #include #include #include std::vector cuda_ba( torch::Tensor poses, torch::Tensor patches, torch::Tensor intrinsics, torch::Tensor target, torch::Tensor weight, torch::Tensor lmbda, torch::Tensor ii, torch::Tensor jj, torch::Tensor kk, const int PPF, int t0, int t1, int iterations, bool eff_impl); torch::Tensor cuda_reproject( torch::Tensor poses, torch::Tensor patches, torch::Tensor intrinsics, torch::Tensor ii, torch::Tensor jj, torch::Tensor kk); std::vector ba( torch::Tensor poses, torch::Tensor patches, torch::Tensor intrinsics, torch::Tensor target, torch::Tensor weight, torch::Tensor lmbda, torch::Tensor ii, torch::Tensor jj, torch::Tensor kk, int PPF, int t0, int t1, int iterations, bool eff_impl) { return cuda_ba(poses, patches, intrinsics, target, weight, lmbda, ii, jj, kk, PPF, t0, t1, iterations, eff_impl); } torch::Tensor reproject( torch::Tensor poses, torch::Tensor patches, torch::Tensor intrinsics, torch::Tensor ii, torch::Tensor jj, torch::Tensor kk) { return cuda_reproject(poses, patches, intrinsics, ii, jj, kk); } std::vector neighbors(torch::Tensor ii, torch::Tensor jj) { auto tup = torch::_unique(ii, true, true); torch::Tensor uniq = std::get<0>(tup).to(torch::kCPU); torch::Tensor perm = std::get<1>(tup).to(torch::kCPU); jj = jj.to(torch::kCPU); auto jj_accessor = jj.accessor(); auto perm_accessor = perm.accessor(); std::vector> index(uniq.size(0)); for (int i=0; i < ii.size(0); i++) { index[perm_accessor[i]].push_back(i); } auto opts = torch::TensorOptions().dtype(torch::kInt64); torch::Tensor ix = torch::empty({ii.size(0)}, opts); torch::Tensor jx = torch::empty({ii.size(0)}, opts); auto ix_accessor = ix.accessor(); auto jx_accessor = jx.accessor(); for (int i=0; i& idx = index[i]; std::stable_sort(idx.begin(), idx.end(), [&jj_accessor](size_t i, size_t j) {return jj_accessor[i] < jj_accessor[j];}); for (int i=0; i < idx.size(); i++) { ix_accessor[idx[i]] = (i > 0) ? idx[i-1] : -1; jx_accessor[idx[i]] = (i < idx.size() - 1) ? idx[i+1] : -1; } } ix = ix.to(torch::kCUDA); jx = jx.to(torch::kCUDA); return {ix, jx}; } typedef Eigen::SparseMatrix SpMat; typedef Eigen::Triplet T; Eigen::VectorXd solve(const SpMat &A, const Eigen::VectorXd &b, int freen){ if (freen < 0){ const Eigen::SimplicialCholesky chol(A); return chol.solve(b); // n x 1 } const SpMat A_sub = A.topLeftCorner(freen, freen); const Eigen::VectorXd b_sub = b.topRows(freen); const Eigen::VectorXd delta = solve(A_sub, b_sub, -7); Eigen::VectorXd delta2(b.rows()); delta2.setZero(); delta2.topRows(freen) = delta; return delta2; } std::vector solve_system(torch::Tensor J_Ginv_i, torch::Tensor J_Ginv_j, torch::Tensor ii, torch::Tensor jj, torch::Tensor res, float ep, float lm, int freen) { const torch::Device device = res.device(); J_Ginv_i = J_Ginv_i.to(torch::kCPU); J_Ginv_j = J_Ginv_j.to(torch::kCPU); ii = ii.to(torch::kCPU); jj = jj.to(torch::kCPU); res = res.clone().to(torch::kCPU); const int r = res.size(0); const int n = std::max(ii.max().item(), jj.max().item()) + 1; res.resize_({r*7}); float *res_ptr = res.data_ptr(); Eigen::Map v(res_ptr, r*7); SpMat J(r*7, n*7); std::vector tripletList; tripletList.reserve(r*7*7*2); auto ii_acc = ii.accessor(); auto jj_acc = jj.accessor(); auto J_Ginv_i_acc = J_Ginv_i.accessor(); auto J_Ginv_j_acc = J_Ginv_j.accessor(); for (int x=0; x()); SpMat A = Jt * J; A.diagonal() += (A.diagonal() * lm); A.diagonal().array() += ep; Eigen::VectorXf delta = solve(A, b, freen*7).cast(); torch::Tensor delta_tensor = torch::from_blob(delta.data(), {n*7}).clone().to(device); delta_tensor.resize_({n, 7}); return {delta_tensor}; Eigen::Matrix dense_J(J.cast()); torch::Tensor dense_J_tensor = torch::from_blob(dense_J.data(), {r*7, n*7}).clone().to(device); dense_J_tensor.resize_({r, 7, n, 7}); return {delta_tensor, dense_J_tensor}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &ba, "BA forward operator"); m.def("neighbors", &neighbors, "temporal neighboor indicies"); m.def("reproject", &reproject, "temporal neighboor indicies"); m.def("solve_system", &solve_system, "temporal neighboor indicies"); }