#include <c10/macros/Macros.h>
#include <c10/util/Synchronized.h>
#include <array>
#include <atomic>
#include <cstdint>
#include <mutex>
#include <thread>
namespace c10 {
namespace detail {
struct IncrementRAII final {
public:
explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter(counter) {
_counter->fetch_add(1);
}
~IncrementRAII() {
_counter->fetch_sub(1);
}
private:
std::atomic<int32_t>* _counter;
C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII);
};
} // namespace detail
// LeftRight wait-free readers synchronization primitive
// https://hal.archives-ouvertes.fr/hal-01207881/document
//
// LeftRight is quite easy to use (it can make an arbitrary
// data structure permit wait-free reads), but it has some
// particular performance characteristics you should be aware
// of if you're deciding to use it:
//
// - Reads still incur an atomic write (this is how LeftRight
// keeps track of how long it needs to keep around the old
// data structure)
//
// - Writes get executed twice, to keep both the left and right
// versions up to date. So if your write is expensive or
// nondeterministic, this is also an inappropriate structure
//
// LeftRight is used fairly rarely in PyTorch's codebase. If you
// are still not sure whether you need it, consult your local
// C++ expert.
//
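// A minimal usage sketch (not from this file; the member name `counts_` and
// the std::vector payload are made up for illustration):
//
//   LeftRight<std::vector<int>> counts_;
//
//   // read() hands the functor a const reference to the current foreground
//   // copy and passes the functor's return value through.
//   int first = counts_.read([](const std::vector<int>& v) {
//     return v.empty() ? -1 : v.front();
//   });
//
//   // write() runs the functor once per copy (i.e. twice), serialized by an
//   // internal mutex, so it should be cheap and deterministic.
//   counts_.write([](std::vector<int>& v) { v.push_back(1); });
//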
template <class T>
class LeftRight final {
public:
template <class... Args>
explicit LeftRight(const Args&... args)
: _counters{{{0}, {0}}},
_foregroundCounterIndex(0),
_foregroundDataIndex(0),
_data{{T{args...}, T{args...}}},
_writeMutex() {}
// Copying and moving would not be threadsafe.
// Needs more thought and careful design to make that work.
LeftRight(const LeftRight&) = delete;
LeftRight(LeftRight&&) noexcept = delete;
LeftRight& operator=(const LeftRight&) = delete;
LeftRight& operator=(LeftRight&&) noexcept = delete;
~LeftRight() {
// wait until any potentially running writers are finished
{ std::unique_lock<std::mutex> lock(_writeMutex); }
// wait until any potentially running readers are finished
while (_counters[0].load() != 0 || _counters[1].load() != 0) {
std::this_thread::yield();
}
}
template <typename F>
auto read(F&& readFunc) const {
detail::IncrementRAII _increment_counter(
&_counters[_foregroundCounterIndex.load()]);
return std::forward<F>(readFunc)(_data[_foregroundDataIndex.load()]);
}
// Throwing an exception in writeFunc is ok, but it leaves the state as either
// the old or the new one, depending on whether the first or the second call
// to writeFunc threw: if the first call throws, readers keep seeing the old
// state; if the second call throws, the new state has already been published.
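//
// A hypothetical sketch of that behavior (both copies of the int start
// value-initialized to 0; assumes <stdexcept> is available):
//
//   LeftRight<int> x;
//   try {
//     x.write([](int& v) {
//       v = 1;
//       throw std::runtime_error("boom"); // thrown on the first call
//     });
//   } catch (const std::runtime_error&) {
//     // The first call threw, so readers still observe the old value 0: the
//     // partially written background copy was restored from the foreground.
//   }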
template <typename F>
auto write(F&& writeFunc) {
std::unique_lock<std::mutex> lock(_writeMutex);
return _write(std::forward<F>(writeFunc));
}
private:
template <class F>
auto _write(const F& writeFunc) {
/*
* Assume A is in the background and B in the foreground. In simplified terms, we
* want to do the following:
* 1. Write to A (old background)
* 2. Switch A/B
* 3. Write to B (new background)
*
* More detailed algorithm (explanations on why this is important are below
* in code):
* 1. Write to A
* 2. Switch A/B data pointers
* 3. Wait until A counter is zero
* 4. Switch A/B counters
* 5. Wait until B counter is zero
* 6. Write to B
*/
auto localDataIndex = _foregroundDataIndex.load();
// 1. Write to A
_callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
// 2. Switch A/B data pointers
localDataIndex = localDataIndex ^ 1;
_foregroundDataIndex = localDataIndex;
/*
* 3. Wait until A counter is zero
*
* In the previous write run, A was foreground and B was background.
* There was a time after switching _foregroundDataIndex (B to foreground)
* and before switching _foregroundCounterIndex, in which new readers could
* have read B but incremented A's counter.
*
* In this current run, we just switched _foregroundDataIndex (A back to
* foreground), but before writing to the new background B, we have to wait
* until A's counter has dropped to zero, so that all of those old readers
* are gone.
*/
auto localCounterIndex = _foregroundCounterIndex.load();
_waitForBackgroundCounterToBeZero(localCounterIndex);
/*
* 4. Switch A/B counters
*
* Now that we know that all old readers of B which were registered on A's
* counter are gone, we can switch the counters and have new readers
* increment A's counter again, which is the correct counter since they're
* reading A.
*/
localCounterIndex = localCounterIndex ^ 1;
_foregroundCounterIndex = localCounterIndex;
/*
* 5. Wait until B counter is zero
*
* This waits for all the readers on B that came in while both data and
* counter for B were in the foreground, i.e. normal readers that happened
* outside of that brief gap between switching data and counter.
*/
_waitForBackgroundCounterToBeZero(localCounterIndex);
// 6. Write to B
return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
}
template <class F>
auto _callWriteFuncOnBackgroundInstance(
const F& writeFunc,
uint8_t localDataIndex) {
try {
return writeFunc(_data[localDataIndex ^ 1]);
} catch (...) {
// recover invariant by copying from the foreground instance
_data[localDataIndex ^ 1] = _data[localDataIndex];
// rethrow
throw;
}
}
void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
while (_counters[counterIndex ^ 1].load() != 0) {
std::this_thread::yield();
}
}
mutable std::array<std::atomic<int32_t>, 2> _counters;
std::atomic<uint8_t> _foregroundCounterIndex;
std::atomic<uint8_t> _foregroundDataIndex;
std::array<T, 2> _data;
std::mutex _writeMutex;
};
// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a
// read-write lock to protect T (data).
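//
// A usage sketch (mirrors LeftRight; the member name `table_` is made up):
//
//   RWSafeLeftRightWrapper<std::vector<int>> table_;
//   table_.write([](std::vector<int>& v) { v.push_back(1); });
//   auto size = table_.read([](const std::vector<int>& v) { return v.size(); });
//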
template <class T>
class RWSafeLeftRightWrapper final {
public:
template <class... Args>
explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {}
// RWSafeLeftRightWrapper is not copyable or moveable since LeftRight
// is not copyable or moveable.
RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete;
RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete;
RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete;
RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete;
template <typename F>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
auto read(F&& readFunc) const {
return data_.withLock(
[&readFunc](T const& data) { return std::forward<F>(readFunc)(data); });
}
template <typename F>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
auto write(F&& writeFunc) {
return data_.withLock(
[&writeFunc](T& data) { return std::forward<F>(writeFunc)(data); });
}
private:
c10::Synchronized<T> data_;
};
} // namespace c10