File size: 584 Bytes
8a42f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <torch/extension.h>
#include <torch/csrc/utils/tensor_flatten.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h

at::Tensor flatten(std::vector<at::Tensor> tensors)
{
  return torch::utils::flatten_dense_tensors(tensors);
}

std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
  return torch::utils::unflatten_dense_tensors(flat, tensors);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("flatten", &flatten, "Flatten dense tensors");
  m.def("unflatten", &unflatten, "Unflatten dense tensors");
}