// File size: 3,366 Bytes
// c61ccee
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Pool.h>

namespace at::native {

// Validates the arguments of max_pool1d(), reporting the first violated
// constraint through TORCH_CHECK.
//
// Parameters:
//   self        - input tensor; must be 2D or 3D.
//   kernel_size - exactly one element, greater than zero.
//   stride      - empty (then kernel_size is used) or exactly one element,
//                 greater than zero.
//   padding     - exactly one element, non-negative and at most
//                 kernel_size / 2.
//   dilation    - exactly one element, greater than zero.
//   ceil_mode   - forwarded to pooling_output_shape() when computing the output
//                 width, which must come out strictly positive.
//
// Declared `inline` rather than `static`. A `static` function defined in a
// header gives every including translation unit its own private copy, and it
// triggers -Wunused-function wherever the header is included but the check is
// not called. `inline` lets the identical definitions be merged at link time.
inline void check_max_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  TORCH_CHECK(
      self.dim() == 2 || self.dim() == 3,
      "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
  TORCH_CHECK(
      kernel_size.size() == 1,
      "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
      kernel_size.size());
  TORCH_CHECK(
      stride.empty() || stride.size() == 1,
      "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
      stride.size());
  TORCH_CHECK(
      padding.size() == 1,
      "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
      padding.size());
  TORCH_CHECK(
      dilation.size() == 1,
      "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
      dilation.size());

  // If stride=None then set it to kernel_size. `stride` is a by-value view, so
  // this only affects the checks below, not the caller's argument.
  if (stride.empty()) {
    stride = kernel_size;
  }

  TORCH_CHECK(
      kernel_size[0] > 0,
      "max_pool1d() kernel_size must be greater than zero, but got ",
      kernel_size[0]);
  TORCH_CHECK(
      stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
  TORCH_CHECK(
      padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
  TORCH_CHECK(
      padding[0] <= kernel_size[0] / 2,
      "max_pool1d() padding should be at most half of kernel size, but got padding=",
      padding[0],
      " and kernel_size=",
      kernel_size[0]);
  TORCH_CHECK(
      dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);

  // The output width is computed from the last (width) dimension.
  // guard_int() converts the possibly-symbolic size into a concrete int64_t.
  const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
  TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
}

// TODO(Heitor) Template by dimension
// Geometry of a 1D pooling problem along its single spatial (width) axis.
// The helpers map (kernel tap, output column) pairs to input columns and clip
// output ranges so that a given kernel tap only reads in-bounds input columns.
struct PoolingParams1D {
  int64_t NB; // Number of batches
  int64_t NC; // Number of channels
  int64_t IW; // Input width
  int64_t OW; // Output width
  int64_t KW; // Kernel width
  int64_t SJ; // Column stride
  int64_t PJ; // Column padding
  int64_t DJ; // Column dilation

  // Input column read by kernel tap `kj` while producing output column `oj`.
  // The result is negative, or >= IW, when the tap lands in the padding.
  int64_t index(int64_t kj, int64_t oj) const {
    return oj * SJ + kj * DJ - PJ;
  }

  // First output column for which kernel tap `kj` reads input column >= 0.
  // Earlier output columns would read from the left padding.
  int64_t valid_output_start(int64_t kj) const {
    const int64_t first_in = index(kj, 0);
    if (first_in >= 0) {
      return 0;
    }
    // Number of stride steps needed to climb from first_in up to column 0.
    return at::divup(-first_in, SJ);
  }

  // One past the last output column for which kernel tap `kj` reads input
  // column <= IW - 1. Later output columns would read from the right padding.
  int64_t valid_output_end(int64_t kj) const {
    const int64_t last_in = index(kj, OW - 1);
    if (last_in < IW) {
      return OW;
    }
    // Drop as many trailing outputs as stride steps are needed to pull the
    // last read back inside the input.
    const int64_t overshoot = last_in - (IW - 1);
    return OW - at::divup(overshoot, SJ);
  }
};

// Function-pointer type for a 1D max-pooling kernel: (Tensor&, const Tensor&,
// const PoolingParams1D&) -> void.
// NOTE(review): presumably the non-const first Tensor is the output filled by
// the kernel and the const second Tensor is the input; confirm against the
// registered implementations.
using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);

// Declares `max_pool1d_stub`, a dispatch stub of type pooling_fn, using the
// DECLARE_DISPATCH macro from ATen/native/DispatchStub.h. Backend kernels are
// presumably registered against it in other translation units.
DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);

} // namespace at::native