#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at::native {

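// Signature shared by all padding kernel stubs declared below. Descriptive
// note (inferred from the kernel implementations, not enforced by the type):
// the first Tensor is the destination and the second the source, i.e.
// (output, input, padding) for forward kernels and
// (grad_input, grad_output, padding) for backward kernels.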
using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);

// reflection padding
DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel);

// replication padding
DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
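// Illustrative sketch of how these stubs are typically wired up (the file and
// implementation names below are examples, not defined in this header): the
// operator .cpp defines each stub once with DEFINE_DISPATCH, and a
// backend-specific kernel file binds a concrete implementation with
// REGISTER_DISPATCH, e.g.
//
//   // ReflectionPad.cpp
//   DEFINE_DISPATCH(reflection_pad1d_kernel);
//
//   // cpu/PaddingKernel.cpp
//   static void reflection_pad1d_kernel_impl(
//       const Tensor& output, const Tensor& input, IntArrayRef padding) {
//     // ... fill `output` by reflecting `input` across its padded border ...
//   }
//   REGISTER_DISPATCH(reflection_pad1d_kernel, &reflection_pad1d_kernel_impl);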

namespace padding {

template <int dim>
static inline void check_valid_input(const Tensor& input, IntArrayRef padding) {

  TORCH_CHECK(padding.size() == 2 * dim,
      "padding size is expected to be ", 2 * dim,
      ", but got: ", padding.size());

  int input_dim = input.dim();

  bool is_batch_mode = input_dim == (dim + 2);

  bool valid_batch_mode = is_batch_mode;
  bool valid_non_batch_mode = !is_batch_mode;

  if (is_batch_mode) {
    // allow the batch dimension (dim 0) to be zero; all remaining
    // dimensions must be non-zero.
    for (const auto d : c10::irange(1, input_dim)) {
      valid_batch_mode = valid_batch_mode && input.size(d) != 0;
    }
  } else {
    for (const auto d : c10::irange(0, input_dim)) {
      valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
    }
  }

  // allow an empty batch dimension, but no other empty dimensions.
  TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
      "Expected ", dim + 1, "D or ", dim + 2,
      "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
      input.sizes());
}
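// Illustrative usage (assumed caller, not part of this header): a 2-D padding
// op expects a 3-D (C, H, W) or 4-D (N, C, H, W) input together with four
// padding values {left, right, top, bottom}, so its shape check would be
//
//   check_valid_input<2>(input, padding);
//
// which passes for e.g. input.sizes() == {0, 3, 8, 8} (empty batch allowed)
// and throws for input.sizes() == {3, 0, 8} (zero-sized non-batch dimension).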

} // namespace padding

} // namespace at::native