File size: 1,803 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

#include <limits>

namespace at {
namespace native {

/// Validates the arguments of pixel_shuffle.
///
/// Requirements, checked in this order (the first violation is reported
/// through TORCH_CHECK):
///   1. `self` has at least 3 dimensions.
///   2. `upscale_factor` is strictly positive.
///   3. `upscale_factor * upscale_factor` is representable as int64_t.
///   4. `self.size(-3)` (the 'channel' dimension) is divisible by the square
///      of `upscale_factor`.
///
/// @param self            Input tensor whose shape is validated.
/// @param upscale_factor  Upscaling factor to validate.
inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
  TORCH_CHECK(self.dim() >= 3,
              "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
              self.dim(), " dimension(s)");
  TORCH_CHECK(upscale_factor > 0,
              "pixel_shuffle expects a positive upscale_factor, but got ",
              upscale_factor);
  // Signed integer overflow is undefined behaviour, so reject any factor whose
  // square does not fit in int64_t *before* performing the multiplication.
  // upscale_factor is known to be positive here, so the division is safe.
  TORCH_CHECK(upscale_factor <= std::numeric_limits<int64_t>::max() / upscale_factor,
              "pixel_shuffle: upscale_factor is too large, (upscale_factor)^2 overflowed: upscale_factor=",
              upscale_factor);
  const int64_t c = self.size(-3);
  const int64_t upscale_factor_squared = upscale_factor * upscale_factor;
  TORCH_CHECK(c % upscale_factor_squared == 0,
              "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
              "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
}

/// Validates the arguments of pixel_unshuffle.
///
/// Requirements, checked in this order (the first violation is reported
/// through TORCH_CHECK):
///   1. `self` has at least 3 dimensions.
///   2. `downscale_factor` is strictly positive.
///   3. `self.size(-2)` (height) is divisible by `downscale_factor`.
///   4. `self.size(-1)` (width) is divisible by `downscale_factor`.
///
/// @param self              Input tensor whose shape is validated.
/// @param downscale_factor  Downscaling factor to validate.
inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
  TORCH_CHECK(self.dim() >= 3,
              "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
              self.dim(), " dimension(s)");
  TORCH_CHECK(downscale_factor > 0,
              "pixel_unshuffle expects a positive downscale_factor, but got ",
              downscale_factor);

  // The divisor was verified to be positive above, so the modulo operations
  // below cannot divide by zero.
  const int64_t height = self.size(-2);
  TORCH_CHECK(height % downscale_factor == 0,
              "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
              height, " is not divisible by ", downscale_factor);

  const int64_t width = self.size(-1);
  TORCH_CHECK(width % downscale_factor == 0,
              "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
              width, " is not divisible by ", downscale_factor);
}

}} // namespace at::native