Kano001's picture
Upload 5252 files
c61ccee verified
raw
history blame
436 Bytes
#pragma once
#include <ATen/cuda/CUDAContext.h>
namespace at::cuda {
// Reports whether all tensors in `ts` reside on the current CUDA device.
//
// Each tensor's device() is compared against Device(kCUDA, current_device()).
// An empty list is vacuously accepted (returns true) and does not query
// current_device() at all.
//
// @param ts  tensors to inspect (non-owning view).
// @return    true if the list is empty or every tensor's device equals the
//            current CUDA device; false as soon as one tensor differs.
inline bool check_device(ArrayRef<Tensor> ts) {
  if (!ts.empty()) {
    // Only look up the current device when there is something to compare.
    const Device expected(kCUDA, current_device());
    for (const auto& tensor : ts) {
      if (tensor.device() != expected) {
        return false;
      }
    }
  }
  return true;
}
} // namespace at::cuda