Spaces:
Running
Running
File size: 436 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
#pragma once
#include <ATen/cuda/CUDAContext.h>
namespace at::cuda {
/// Reports whether all tensors in `ts` reside on the CUDA device returned by
/// `current_device()`.
///
/// @param ts  Tensors to inspect; an empty list is accepted trivially.
/// @return    true if every tensor's device equals Device(kCUDA,
///            current_device()) (or `ts` is empty), false on the first
///            mismatch.
inline bool check_device(ArrayRef<Tensor> ts) {
  // Return before building the reference device so that an empty list
  // never calls current_device().
  if (ts.empty()) {
    return true;
  }
  const Device expected(kCUDA, current_device());
  for (const Tensor& tensor : ts) {
    if (tensor.device() != expected) {
      return false;
    }
  }
  return true;
}
} // namespace at::cuda
|