Spaces:
Running
Running
namespace c10 { | |
namespace detail { | |
struct IncrementRAII final { | |
public: | |
explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter(counter) { | |
_counter->fetch_add(1); | |
} | |
~IncrementRAII() { | |
_counter->fetch_sub(1); | |
} | |
private: | |
std::atomic<int32_t>* _counter; | |
C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII); | |
}; | |
} // namespace detail | |
// LeftRight wait-free readers synchronization primitive | |
// https://hal.archives-ouvertes.fr/hal-01207881/document | |
// | |
// LeftRight is quite easy to use (it can make an arbitrary | |
// data structure permit wait-free reads), but it has some | |
// particular performance characteristics you should be aware | |
// of if you're deciding to use it: | |
// | |
// - Reads still incur an atomic write (this is how LeftRight | |
// keeps track of how long it needs to keep around the old | |
// data structure) | |
// | |
// - Writes get executed twice, to keep both the left and right | |
// versions up to date. So if your write is expensive or | |
// nondeterministic, this is also an inappropriate structure | |
// | |
// LeftRight is used fairly rarely in PyTorch's codebase. If you | |
// are still not sure if you need it or not, consult your local | |
// C++ expert. | |
// | |
template <class T> | |
class LeftRight final { | |
public: | |
template <class... Args> | |
explicit LeftRight(const Args&... args) | |
: _counters{{{0}, {0}}}, | |
_foregroundCounterIndex(0), | |
_foregroundDataIndex(0), | |
_data{{T{args...}, T{args...}}}, | |
_writeMutex() {} | |
// Copying and moving would not be threadsafe. | |
// Needs more thought and careful design to make that work. | |
LeftRight(const LeftRight&) = delete; | |
LeftRight(LeftRight&&) noexcept = delete; | |
LeftRight& operator=(const LeftRight&) = delete; | |
LeftRight& operator=(LeftRight&&) noexcept = delete; | |
~LeftRight() { | |
// wait until any potentially running writers are finished | |
{ std::unique_lock<std::mutex> lock(_writeMutex); } | |
// wait until any potentially running readers are finished | |
while (_counters[0].load() != 0 || _counters[1].load() != 0) { | |
std::this_thread::yield(); | |
} | |
} | |
template <typename F> | |
auto read(F&& readFunc) const { | |
detail::IncrementRAII _increment_counter( | |
&_counters[_foregroundCounterIndex.load()]); | |
return std::forward<F>(readFunc)(_data[_foregroundDataIndex.load()]); | |
} | |
// Throwing an exception in writeFunc is ok but causes the state to be either | |
// the old or the new state, depending on if the first or the second call to | |
// writeFunc threw. | |
template <typename F> | |
auto write(F&& writeFunc) { | |
std::unique_lock<std::mutex> lock(_writeMutex); | |
return _write(std::forward<F>(writeFunc)); | |
} | |
private: | |
template <class F> | |
auto _write(const F& writeFunc) { | |
/* | |
* Assume, A is in background and B in foreground. In simplified terms, we | |
* want to do the following: | |
* 1. Write to A (old background) | |
* 2. Switch A/B | |
* 3. Write to B (new background) | |
* | |
* More detailed algorithm (explanations on why this is important are below | |
* in code): | |
* 1. Write to A | |
* 2. Switch A/B data pointers | |
* 3. Wait until A counter is zero | |
* 4. Switch A/B counters | |
* 5. Wait until B counter is zero | |
* 6. Write to B | |
*/ | |
auto localDataIndex = _foregroundDataIndex.load(); | |
// 1. Write to A | |
_callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); | |
// 2. Switch A/B data pointers | |
localDataIndex = localDataIndex ^ 1; | |
_foregroundDataIndex = localDataIndex; | |
/* | |
* 3. Wait until A counter is zero | |
* | |
* In the previous write run, A was foreground and B was background. | |
* There was a time after switching _foregroundDataIndex (B to foreground) | |
* and before switching _foregroundCounterIndex, in which new readers could | |
* have read B but incremented A's counter. | |
* | |
* In this current run, we just switched _foregroundDataIndex (A back to | |
* foreground), but before writing to the new background B, we have to make | |
* sure A's counter was zero briefly, so all these old readers are gone. | |
*/ | |
auto localCounterIndex = _foregroundCounterIndex.load(); | |
_waitForBackgroundCounterToBeZero(localCounterIndex); | |
/* | |
* 4. Switch A/B counters | |
* | |
* Now that we know all readers on B are really gone, we can switch the | |
* counters and have new readers increment A's counter again, which is the | |
* correct counter since they're reading A. | |
*/ | |
localCounterIndex = localCounterIndex ^ 1; | |
_foregroundCounterIndex = localCounterIndex; | |
/* | |
* 5. Wait until B counter is zero | |
* | |
* This waits for all the readers on B that came in while both data and | |
* counter for B was in foreground, i.e. normal readers that happened | |
* outside of that brief gap between switching data and counter. | |
*/ | |
_waitForBackgroundCounterToBeZero(localCounterIndex); | |
// 6. Write to B | |
return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); | |
} | |
template <class F> | |
auto _callWriteFuncOnBackgroundInstance( | |
const F& writeFunc, | |
uint8_t localDataIndex) { | |
try { | |
return writeFunc(_data[localDataIndex ^ 1]); | |
} catch (...) { | |
// recover invariant by copying from the foreground instance | |
_data[localDataIndex ^ 1] = _data[localDataIndex]; | |
// rethrow | |
throw; | |
} | |
} | |
void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) { | |
while (_counters[counterIndex ^ 1].load() != 0) { | |
std::this_thread::yield(); | |
} | |
} | |
mutable std::array<std::atomic<int32_t>, 2> _counters; | |
std::atomic<uint8_t> _foregroundCounterIndex; | |
std::atomic<uint8_t> _foregroundDataIndex; | |
std::array<T, 2> _data; | |
std::mutex _writeMutex; | |
}; | |
// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a | |
// read-write lock to protect T (data). | |
template <class T> | |
class RWSafeLeftRightWrapper final { | |
public: | |
template <class... Args> | |
explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {} | |
// RWSafeLeftRightWrapper is not copyable or moveable since LeftRight | |
// is not copyable or moveable. | |
RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete; | |
RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete; | |
RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete; | |
RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete; | |
template <typename F> | |
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) | |
auto read(F&& readFunc) const { | |
return data_.withLock( | |
[&readFunc](T const& data) { return std::forward<F>(readFunc)(data); }); | |
} | |
template <typename F> | |
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) | |
auto write(F&& writeFunc) { | |
return data_.withLock( | |
[&writeFunc](T& data) { return std::forward<F>(writeFunc)(data); }); | |
} | |
private: | |
c10::Synchronized<T> data_; | |
}; | |
} // namespace c10 | |