#pragma once
#include <cstdint>
#include <type_traits>
#include <c10/core/ScalarType.h>
#include <ATen/cuda/CUDAConfig.h>

// NOTE: These templates are intentionally not defined in this header,
// which avoids re-compiling them for each translation unit. If you get
// a link error, you need to add an explicit instantiation for your
// types in cub.cu
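//
// An explicit instantiation in cub.cu looks roughly like the following
// (illustrative example only; the actual set of instantiated types is
// maintained in cub.cu):
//
//   template void radix_sort_keys(
//       const int64_t *keys_in, int64_t *keys_out,
//       int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);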

namespace at::cuda::cub {

// Returns the number of bits needed to represent keys in [0, max_key],
// e.g. get_num_bits(1) == 1, get_num_bits(255) == 8, get_num_bits(256) == 9.
inline int get_num_bits(uint64_t max_key) {
  int num_bits = 1;
  while (max_key > 1) {
    max_key >>= 1;
    num_bits++;
  }
  return num_bits;
}

namespace detail {

// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
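// For example (illustrative): sorting float values and sorting int32_t values
// both dispatch to the same radix_sort_pairs_impl<key_t, 4> instantiation,
// because both value types are reinterpreted as OpaqueType<4>.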

template<typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t *keys_in, key_t *keys_out,
    const OpaqueType<value_size> *values_in, OpaqueType<value_size> *values_out,
    int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);

}  // namespace detail

template<typename key_t, typename value_t>
void radix_sort_pairs(
    const key_t *keys_in, key_t *keys_out,
    const value_t *values_in, value_t *values_out,
    int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8) {
  static_assert(std::is_trivially_copyable<value_t>::value ||
                AT_ROCM_ENABLED(),  // ROCm incorrectly fails this check for vector types
                "radix_sort_pairs value type must be trivially copyable");
  // Make value type opaque, so all inputs of a certain size use the same template instantiation
  using opaque_t = detail::OpaqueType<sizeof(value_t)>;
  static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
                "This size of value_t is not instantiated. Please instantiate it in cub.cu"
                " and modify this check.");
  static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned");
  detail::radix_sort_pairs_impl(
      keys_in, keys_out,
      reinterpret_cast<const opaque_t*>(values_in),
      reinterpret_cast<opaque_t*>(values_out),
      n, descending, begin_bit, end_bit);
}
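
// Example usage (illustrative sketch; assumes the pointers refer to device
// memory holding n elements and that this key/value combination is
// instantiated in cub.cu):
//
//   const int64_t *d_keys_in;  int64_t *d_keys_out;   // device pointers
//   const float   *d_vals_in;  float   *d_vals_out;   // device pointers
//   at::cuda::cub::radix_sort_pairs(d_keys_in, d_keys_out,
//                                   d_vals_in, d_vals_out, n,
//                                   /*descending=*/false);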

template<typename key_t>
void radix_sort_keys(
    const key_t *keys_in, key_t *keys_out,
    int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8);

// NOTE: Intermediate sums will be truncated to input_t precision
template <typename input_t, typename output_t>
void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n);
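// For example (illustrative): with input_t = uint8_t, the running sum wraps
// modulo 256, so inputs {200, 100} produce outputs {200, 44} even when
// output_t is a wider type.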

template <typename scalar_t>
void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  return inclusive_sum_truncating(input, output, n);
}

// NOTE: Sums are done in common_type<input_t, output_t>
template <typename input_t, typename output_t>
void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n);
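// For example (illustrative): with input_t = uint8_t and output_t = int64_t,
// the accumulation happens in int64_t, so intermediate sums do not wrap.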

template <typename scalar_t>
void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  return exclusive_sum_in_common_type(input, output, n);
}

void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n);
inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) {
  return mask_exclusive_sum(
      reinterpret_cast<const uint8_t*>(mask), output_idx, n);
}
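
// For example (illustrative): mask = {1, 0, 1, 1, 0} yields
// output_idx = {0, 1, 1, 2, 3}, i.e. the exclusive prefix sum of the mask,
// which gives each true element its index among the true elements.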

}  // namespace at::cuda::cub