Techt3o's picture
1c0fe995ac3bf6c4bc83a727a73c46ab2d045729fb0abd53c4c78cd2b8282877
20ae9ff verified
raw
history blame
759 Bytes
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
// Scalar element type for this module's tensor data.
// NOTE(review): must agree with the dtype baked into mdtype (float <-> kFloat32);
// nothing enforces this, so update both together.
using mtype = float;

// Default tensor options used by this module: float32 tensors on the CUDA device.
// Being a namespace-scope const, this has internal linkage, so every translation
// unit that includes the header gets its own (identical) copy.
const auto mdtype = torch::dtype(torch::kFloat32).device(torch::kCUDA);
// EfficentE: stores tensors built from the constructor inputs and exposes
// const tensor-product helpers. All member functions are defined outside
// this header.
// NOTE(review): judging only by the method names, "E" looks like a sparse
// block matrix and the helpers compute E*Q*E^T, E*v and E^T*v -- confirm
// against the out-of-line definitions.
class EfficentE
{
  private:
    // Internal index tensors (presumably derived from ii/jj/ku -- TODO confirm).
    torch::Tensor block_index_tensor;
    torch::Tensor index_tensor;
    torch::Tensor patch_to_ku;

    // Set once at construction. Because it is const, EfficentE is not
    // copy/move-assignable (copy construction still works).
    const int t0;

  public:
    // Set once at construction (presumably from patches_per_frame -- confirm).
    const int ppf;

    torch::Tensor E_lookup;
    torch::Tensor ij_xself;

    /// Builds the object from three index tensors.
    /// @param ii, jj, ku         index tensors (shapes/dtypes not visible here -- TODO confirm)
    /// @param patches_per_frame  presumably stored in ppf
    /// @param t0                 presumably stored in the private t0 member
    EfficentE(const torch::Tensor &ii,
              const torch::Tensor &jj,
              const torch::Tensor &ku,
              int patches_per_frame,
              int t0);

    /// Default constructor; must still initialize the const members t0 and ppf.
    EfficentE();

    /// Presumably E * Q * E^T (see class note). @param N size argument, meaning not visible here.
    torch::Tensor computeEQEt(int N, const torch::Tensor &Q) const;

    /// Presumably E * vec (see class note). @param N size argument, meaning not visible here.
    torch::Tensor computeEv(int N, const torch::Tensor &vec) const;

    /// Presumably E^T * vec (see class note). @param M size argument, meaning not visible here.
    torch::Tensor computeEtv(int M, const torch::Tensor &vec) const;
};