Async++ unknown
Async (co_await/co_return) code for C++
Loading...
Searching...
No Matches
thread_pool.h
1#pragma once
2#include <asyncpp/dispatcher.h>
3#include <atomic>
4#include <cassert>
5#include <condition_variable>
6#include <cstddef>
7#include <memory>
8#include <mutex>
9#include <queue>
10#include <random>
11#include <shared_mutex>
12#include <stdexcept>
13#include <string>
14#include <thread>
15#include <utility>
16#include <vector>
17
18#ifdef __linux__
19#include <pthread.h>
20#endif
21
22namespace asyncpp {
26 class thread_pool : public dispatcher {
27 public:
32 explicit thread_pool(size_t initial_size = std::thread::hardware_concurrency()) { this->resize(initial_size); }
33 ~thread_pool() { this->resize(0); }
34 thread_pool(const thread_pool&) = delete;
35 thread_pool& operator=(const thread_pool&) = delete;
36
41 void push(std::function<void()> cbfn) override {
42 if (!cbfn) return;
43 if (g_current_thread != nullptr) {
44 std::unique_lock lck{g_current_thread->mutex};
45 g_current_thread->queue.emplace(std::move(cbfn));
46 } else {
47 std::shared_lock lck{m_threads_mtx};
48 auto size = m_valid_size.load();
49 if (size == 0) throw std::runtime_error("pool is shutting down");
50 auto thread = m_threads[g_queue_rand() % size].get();
51 std::unique_lock lck2{thread->mutex};
52 thread->queue.emplace(std::move(cbfn));
53 thread->cv.notify_one();
54 }
55 }
56
61 void resize(size_t target_size) {
62 std::unique_lock lck{m_resize_mtx};
63 auto old = m_target_size.load();
64 if (old > target_size) {
65 // Prevent new tasks from being scheduled on those threads
66 m_valid_size = target_size;
67 m_target_size = target_size;
68 // Notify all and join threads, if the threads index is greater or equal to m_target_size,
69 // it will invoke its remaining tasks and exit.
70 for (size_t i = target_size; i < m_threads.size(); i++) {
71 m_threads[i]->cv.notify_all();
72 if (m_threads[i]->thread.joinable()) m_threads[i]->thread.join();
73 assert(m_threads[i]->queue.empty());
74 }
75 std::unique_lock lck{m_threads_mtx};
76 m_threads.resize(target_size);
77 } else if (old < target_size) {
78 m_target_size = target_size;
79 std::unique_lock threads_lck{m_threads_mtx};
80 m_threads.resize(target_size);
81 threads_lck.unlock();
82 // We need some new threads, spawn them with the relevant indexes
83 for (size_t i = old; i < target_size; i++) {
84 m_threads[i] = std::make_unique<thread_state>(this, i);
85 }
86 // Allow pushing work to our new threads
87 m_valid_size = target_size;
88 }
89 assert(target_size == m_threads.size());
90 assert(target_size == m_valid_size);
91 assert(target_size == m_target_size);
92 }
93
98 size_t size() const noexcept { return m_valid_size.load(); }
99
100 private:
101 struct thread_state {
102 thread_pool* const pool;
103 size_t const thread_index;
104 std::mutex mutex{};
105 std::condition_variable cv{};
106 std::queue<std::function<void()>> queue{};
107 std::thread thread;
108
109 thread_state(thread_pool* parent, size_t index)
110 : pool{parent}, thread_index{index}, thread{[this]() { this->run(); }} {}
111
112 std::function<void()> try_steal_task() {
113 // Make sure we dont wait if its locked uniquely cause that might deadlock with resize()
114 if (!pool->m_threads_mtx.try_lock_shared()) return {};
115 std::shared_lock lck{pool->m_threads_mtx, std::adopt_lock};
116 for (size_t i = 0; i < pool->m_valid_size; i++) {
117 auto& thread = pool->m_threads[i];
118 if (thread.get() == this || thread == nullptr) continue;
119 // if the other thread is currently locked skip it, we dont wanna wait too long
120 if (!thread->mutex.try_lock()) continue;
121 std::unique_lock th_lck{thread->mutex, std::adopt_lock};
122 if (thread->queue.empty()) continue;
123 auto cbfn = std::move(thread->queue.front());
124 thread->queue.pop();
125 return cbfn;
126 }
127 return {};
128 }
129
130 void run() {
131#ifdef __linux__
132 {
133 std::string name = "pool_" + std::to_string(thread_index);
134 pthread_setname_np(pthread_self(), name.c_str());
135 }
136#endif
138 g_current_thread = this;
139 while (true) {
140 {
141 std::unique_lock lck{mutex};
142 while (!queue.empty()) {
143 auto cbfn = std::move(queue.front());
144 queue.pop();
145 lck.unlock();
146 cbfn();
147 lck.lock();
148 }
149 }
150 if (thread_index >= pool->m_target_size) break;
151 if (auto cbfn = try_steal_task(); cbfn) {
152 cbfn();
153 continue;
154 }
155 if (thread_index < pool->m_target_size) {
156 std::unique_lock lck{mutex};
157 cv.wait_for(lck, std::chrono::milliseconds{100});
158 }
159 }
160 std::unique_lock lck{mutex};
161 g_current_thread = nullptr;
162 while (!queue.empty()) {
163 auto& cbfn = queue.front();
164 lck.unlock();
165 cbfn();
166 lck.lock();
167 queue.pop();
168 }
169 dispatcher::current(nullptr);
170 }
171 };
172
173 inline static thread_local thread_state* g_current_thread{nullptr};
174 // This makes the conversion explicit to avoid compiler warning/error; only, should the hash not fit in an unsigned int, an error would occur;
175 // Would it be worth it to create a function to convert and check if hash <= unsigned_int?
176 //NOLINTNEXTLINE(cert-err58-cpp)
177 inline static thread_local std::minstd_rand g_queue_rand{
178 static_cast<unsigned int>(std::hash<std::thread::id>{}(std::this_thread::get_id()))};
179 std::atomic<size_t> m_target_size{0};
180 std::atomic<size_t> m_valid_size{0};
181 std::mutex m_resize_mtx{};
182 std::shared_mutex m_threads_mtx{};
183 // We use pointers, so nodes don't move with insert/erase
184 std::vector<std::unique_ptr<thread_state>> m_threads{};
185 };
186} // namespace asyncpp
Basic dispatcher interface class.
Definition dispatcher.h:8
static dispatcher * current() noexcept
Definition dispatcher.h:48
A mutex with an asynchronous lock() operation.
Definition mutex.h:20
A basic thread pool implementation for usage as a dispatcher.
Definition thread_pool.h:26
thread_pool(size_t initial_size=std::thread::hardware_concurrency())
Construct a new thread pool.
Definition thread_pool.h:32
void push(std::function< void()> cbfn) override
Push a callback into the pool.
Definition thread_pool.h:41
size_t size() const noexcept
Get the current number of threads.
Definition thread_pool.h:98
void resize(size_t target_size)
Update the number of threads currently running.
Definition thread_pool.h:61