Async++ unknown
Async (co_await/co_return) code for C++
Loading...
Searching...
No Matches
stop_token.h
Go to the documentation of this file.
1
#pragma once
#ifndef ASYNCPP_FORCE_CUSTOM_STOP_TOKEN
#define ASYNCPP_FORCE_CUSTOM_STOP_TOKEN 0
#endif

#include <version>
#if defined(_LIBCPP_VERSION) || ASYNCPP_FORCE_CUSTOM_STOP_TOKEN
#include <atomic>
#include <cstdint>
#include <thread>
#include <type_traits>
#include <utility>
#else
#include <stop_token>
#endif
18
19namespace asyncpp {
20#if defined(_LIBCPP_VERSION) || ASYNCPP_FORCE_CUSTOM_STOP_TOKEN
/// Tag type selecting the stop_source constructor that creates no stop state.
struct nostopstate_t {
    explicit nostopstate_t() = default;
};

/// Tag value to pass to stop_source(nostopstate_t).
inline constexpr auto nostopstate = nostopstate_t{};

class stop_source;
27
28 class stop_token {
29 public:
30 stop_token() noexcept = default;
31
32 stop_token(const stop_token&) noexcept = default;
33 stop_token(stop_token&&) noexcept = default;
34
35 ~stop_token() = default;
36
37 stop_token& operator=(const stop_token&) noexcept = default;
38 stop_token& operator=(stop_token&&) noexcept = default;
39
40 [[nodiscard]] bool stop_possible() const noexcept {
41 return static_cast<bool>(m_state) && m_state->stop_possible();
42 }
43
44 [[nodiscard]] bool stop_requested() const noexcept {
45 return static_cast<bool>(m_state) && m_state->stop_requested();
46 }
47
48 void swap(stop_token& rhs) noexcept { m_state.swap(rhs.m_state); }
49
50 [[nodiscard]] friend bool operator==(const stop_token& lhs, const stop_token& rhs) {
51 return lhs.m_state == rhs.m_state;
52 }
53
54 friend void swap(stop_token& lhs, stop_token& rhs) noexcept { lhs.swap(rhs); }
55
56 private:
57 friend class stop_source;
58 template<typename _Callback>
59 friend class stop_callback;
60
61 static void yield() noexcept {
62#if defined __i386__ || defined __x86_64__
63 __builtin_ia32_pause();
64#endif
65 std::this_thread::yield();
66 }
67
68 struct binary_semaphore {
69 explicit binary_semaphore(int initial) : m_counter(initial > 0) {}
70
71 void release() { m_counter.fetch_add(1, std::memory_order::release); }
72
73 void acquire() {
74 int old = 1;
75 while (
76 !m_counter.compare_exchange_weak(old, 0, std::memory_order::acquire, std::memory_order::relaxed)) {
77 old = 1;
78 yield();
79 }
80 }
81
82 std::atomic<int> m_counter;
83 };
84
85 struct stop_cb_node_t {
86 using cb_fn_t = void(stop_cb_node_t*) noexcept;
87 cb_fn_t* m_callback;
88 stop_cb_node_t* m_prev = nullptr;
89 stop_cb_node_t* m_next = nullptr;
90 bool* m_destroyed = nullptr;
91 binary_semaphore m_done{0};
92
93 explicit stop_cb_node_t(cb_fn_t* cb) : m_callback(cb) {}
94
95 void run() noexcept { m_callback(this); }
96 };
97
98 class stop_state_t {
99 using value_type = uint32_t;
100 static constexpr value_type mask_stop_requested_bit = 1;
101 static constexpr value_type mask_locked_bit = 2;
102 static constexpr value_type mask_ssrc_counter_inc = 4;
103
104 std::atomic<value_type> m_owners{1};
105 std::atomic<value_type> m_value{mask_ssrc_counter_inc};
106 stop_cb_node_t* m_head = nullptr;
107 std::thread::id m_requester;
108
109 public:
110 stop_state_t() = default;
111
112 bool stop_possible() noexcept { return m_value.load(std::memory_order::acquire) & ~mask_locked_bit; }
113
114 bool stop_requested() noexcept {
115 return m_value.load(std::memory_order::acquire) & mask_stop_requested_bit;
116 }
117
118 void add_owner() noexcept { m_owners.fetch_add(1, std::memory_order::relaxed); }
119
120 void release_ownership() noexcept {
121 if (m_owners.fetch_sub(1, std::memory_order::acq_rel) == 1) delete this;
122 }
123
124 void add_ssrc() noexcept { m_value.fetch_add(mask_ssrc_counter_inc, std::memory_order::relaxed); }
125
126 void sub_ssrc() noexcept { m_value.fetch_sub(mask_ssrc_counter_inc, std::memory_order::release); }
127
128 bool request_stop() noexcept {
129 auto old = m_value.load(std::memory_order::acquire);
130 do {
131 if (old & mask_stop_requested_bit) return false;
132 } while (
133 !try_lock(old, mask_stop_requested_bit, std::memory_order::acq_rel, std::memory_order::acquire));
134
135 m_requester = std::this_thread::get_id();
136
137 while (m_head) {
138 bool is_last_cb{true};
139 stop_cb_node_t* cb = m_head;
140 m_head = m_head->m_next;
141 if (m_head) {
142 m_head->m_prev = nullptr;
143 is_last_cb = false;
144 }
145
146 unlock();
147
148 bool is_destroyed = false;
149 cb->m_destroyed = &is_destroyed;
150
151 cb->run();
152
153 if (!is_destroyed) {
154 cb->m_destroyed = nullptr;
155 cb->m_done.release();
156 }
157
158 if (is_last_cb) return true;
159
160 lock();
161 }
162
163 unlock();
164 return true;
165 }
166
167 bool register_callback(stop_cb_node_t* cb) noexcept {
168 auto old = m_value.load(std::memory_order::acquire);
169 do {
170 if (old & mask_stop_requested_bit) {
171 cb->run();
172 return false;
173 }
174
175 if (old < mask_ssrc_counter_inc) return false;
176 } while (!try_lock(old, 0, std::memory_order::acquire, std::memory_order::acquire));
177
178 cb->m_next = m_head;
179 if (m_head) { m_head->m_prev = cb; }
180 m_head = cb;
181 unlock();
182 return true;
183 }
184
185 void remove_callback(stop_cb_node_t* cb) {
186 lock();
187
188 if (cb == m_head) {
189 m_head = m_head->m_next;
190 if (m_head) m_head->m_prev = nullptr;
191 unlock();
192 return;
193 } else if (cb->m_prev) {
194 cb->m_prev->m_next = cb->m_next;
195 if (cb->m_next) cb->m_next->m_prev = cb->m_prev;
196 unlock();
197 return;
198 }
199
200 unlock();
201
202 if (!(m_requester == std::this_thread::get_id())) {
203 cb->m_done.acquire();
204 return;
205 }
206
207 if (cb->m_destroyed) *cb->m_destroyed = true;
208 }
209
210 private:
211 void lock() noexcept {
212 auto old = m_value.load(std::memory_order::relaxed);
213 while (!try_lock(old, 0, std::memory_order::acquire, std::memory_order::relaxed)) {}
214 }
215
216 void unlock() noexcept { m_value.fetch_sub(mask_locked_bit, std::memory_order::release); }
217
218 bool try_lock(value_type& curval, value_type newbits, std::memory_order success,
219 std::memory_order failure) noexcept {
220 if (curval & mask_locked_bit) {
221 yield();
222 curval = m_value.load(failure);
223 return false;
224 }
225 newbits |= mask_locked_bit;
226 return m_value.compare_exchange_weak(curval, curval | newbits, success, failure);
227 }
228 };
229
230 struct stop_state_ref {
231 stop_state_ref() = default;
232
233 explicit stop_state_ref(const stop_source&) : m_ptr(new stop_state_t()) {}
234
235 stop_state_ref(const stop_state_ref& other) noexcept : m_ptr(other.m_ptr) {
236 if (m_ptr) m_ptr->add_owner();
237 }
238
239 stop_state_ref(stop_state_ref&& other) noexcept : m_ptr(other.m_ptr) { other.m_ptr = nullptr; }
240
241 stop_state_ref& operator=(const stop_state_ref& other) noexcept {
242 if (auto ptr = other.m_ptr; ptr != m_ptr) {
243 if (ptr) ptr->add_owner();
244 if (m_ptr) m_ptr->release_ownership();
245 m_ptr = ptr;
246 }
247 return *this;
248 }
249
250 stop_state_ref& operator=(stop_state_ref&& other) noexcept {
251 stop_state_ref(std::move(other)).swap(*this);
252 return *this;
253 }
254
255 ~stop_state_ref() {
256 if (m_ptr) m_ptr->release_ownership();
257 }
258
259 void swap(stop_state_ref& other) noexcept { std::swap(m_ptr, other.m_ptr); }
260
261 explicit operator bool() const noexcept { return m_ptr != nullptr; }
262
263 stop_state_t* operator->() const noexcept { return m_ptr; }
264
265#if __cpp_impl_three_way_comparison >= 201907L
266 friend bool operator==(const stop_state_ref&, const stop_state_ref&) = default;
267#else
268 friend bool operator==(const stop_state_ref& lhs, const stop_state_ref& rhs) noexcept {
269 return lhs.m_ptr == rhs.m_ptr;
270 }
271
272 friend bool operator!=(const stop_state_ref& lhs, const stop_state_ref& rhs) noexcept {
273 return lhs.m_ptr != rhs.m_ptr;
274 }
275#endif
276
277 private:
278 stop_state_t* m_ptr = nullptr;
279 };
280
281 stop_state_ref m_state;
282
283 explicit stop_token(const stop_state_ref& state) noexcept : m_state{state} {}
284 };
285
287 class stop_source {
288 public:
289 stop_source() : m_state(*this) {}
290
291 explicit stop_source(nostopstate_t) noexcept {}
292
293 stop_source(const stop_source& other) noexcept : m_state(other.m_state) {
294 if (m_state) m_state->add_ssrc();
295 }
296
297 stop_source(stop_source&&) noexcept = default;
298
299 stop_source& operator=(const stop_source& other) noexcept {
300 if (m_state != other.m_state) {
301 stop_source sink(std::move(*this));
302 m_state = other.m_state;
303 if (m_state) m_state->add_ssrc();
304 }
305 return *this;
306 }
307
308 stop_source& operator=(stop_source&&) noexcept = default;
309
310 ~stop_source() {
311 if (m_state) m_state->sub_ssrc();
312 }
313
314 [[nodiscard]] bool stop_possible() const noexcept { return static_cast<bool>(m_state); }
315
316 [[nodiscard]] bool stop_requested() const noexcept {
317 return static_cast<bool>(m_state) && m_state->stop_requested();
318 }
319
320 bool request_stop() const noexcept {
321 if (stop_possible()) return m_state->request_stop();
322 return false;
323 }
324
325 [[nodiscard]] stop_token get_token() const noexcept { return stop_token{m_state}; }
326
327 void swap(stop_source& other) noexcept { m_state.swap(other.m_state); }
328
329 [[nodiscard]] friend bool operator==(const stop_source& a, const stop_source& b) noexcept {
330 return a.m_state == b.m_state;
331 }
332
333 friend void swap(stop_source& lhs, stop_source& rhs) noexcept { lhs.swap(rhs); }
334
335 private:
336 stop_token::stop_state_ref m_state;
337 };
338
340 template<typename Callback>
341 class [[nodiscard]] stop_callback {
342 static_assert(std::is_nothrow_destructible_v<Callback>);
343 static_assert(std::is_invocable_v<Callback>);
344
345 public:
346 using callback_type = Callback;
347
348 template<typename Cb>
349 requires(std::is_constructible_v<Callback, Cb>)
350 explicit stop_callback(const stop_token& token, Cb&& cb) noexcept(std::is_nothrow_constructible_v<Callback, Cb>)
351 : m_cb(std::forward<Cb>(cb)) {
352 if (auto state = token.m_state) {
353 if (state->register_callback(&m_cb)) m_state.swap(state);
354 }
355 }
356
357 template<typename Cb>
358 requires(std::is_constructible_v<Callback, Cb>)
359 explicit stop_callback(stop_token&& token, Cb&& cb) noexcept(std::is_nothrow_constructible_v<Callback, Cb>)
360 : m_cb(std::forward<Cb>(cb)) {
361 if (auto& state = token.m_state) {
362 if (state->register_callback(&m_cb)) m_state.swap(state);
363 }
364 }
365
366 ~stop_callback() {
367 if (m_state) { m_state->remove_callback(&m_cb); }
368 }
369
370 stop_callback(const stop_callback&) = delete;
371 stop_callback& operator=(const stop_callback&) = delete;
372 stop_callback(stop_callback&&) = delete;
373 stop_callback& operator=(stop_callback&&) = delete;
374
375 private:
376 struct cb_impl : stop_token::stop_cb_node_t {
377 template<typename Cb>
378 explicit cb_impl(Cb&& cb) : stop_cb_node_t(&execute), m_cb(std::forward<Cb>(cb)) {}
379
380 Callback m_cb;
381
382 static void execute(stop_cb_node_t* that) noexcept {
383 Callback& cb = static_cast<cb_impl*>(that)->m_cb;
384 std::forward<Callback>(cb)();
385 }
386 };
387
388 cb_impl m_cb;
389 stop_token::stop_state_ref m_state;
390 };
391
392 template<typename Callback>
393 stop_callback(stop_token, Callback) -> stop_callback<Callback>;
394
395#else
396 using stop_source = std::stop_source;
397 using stop_token = std::stop_token;
398 template<typename Callback>
399 using stop_callback = std::stop_callback<Callback>;
400 using nostopstate_t = std::nostopstate_t;
401 inline constexpr nostopstate_t nostopstate{};
402#endif
403} // namespace asyncpp